PyTorch 在深度学习领域中的应用日趋广泛, 得益于它独到的设计. 无论是数据的并行处理还是动态计算图, 一切都为 Python 做出了很多简化. 很多论文都选择使用 PyTorch 去实现也证明了它在训练方面的效率以及易用性.
在 PyTorch 领域, 尽管部署一个模型有很多选择, 可为 Java 开发人员准备的选项却屈指可数.
在过去, 用户可以用 PyTorch C++ 写 JNI (Java Native Interface) 来实现这个过程. 最近, PyTorch 1.4 也发布了试验性的 Java 前端.
可是这两种解决方案都没有办法能让 Java 开发者很好的使用: 用户需要从易于使用和易于维护中二选一.
针对于这个问题, 亚马逊云服务 (AWS)开源了 Deep Java Library (DJL), 一个为 Java 开发者设计的深度学习库. 它兼顾了易用性和可维护性, 一切运行效率以及内存管理问题都得到了很好的处理.
DJL 使用起来异常简单. 只需几行代码, 用户就可以轻松部署深度学习模型用作推理. 那么我们就开始上手用 DJL 部署一个 PyTorch 模型吧.
前期准备
用户可以轻松使用 maven 或者 gradle 等 Java 常用配置管理包来引用 DJL. 下面是一个示例:
- plugins {
- id 'java'
- }
- repositories {
- jcenter()
- }
- dependencies {
- implementation "ai.djl:api:0.4.0"
- implementation "ai.djl:repository:0.4.0"
- runtimeOnly "ai.djl.pytorch:pytorch-model-zoo:0.4.0"
- runtimeOnly "ai.djl.pytorch:pytorch-native-auto:1.4.0"
- }
然后只需 gradle build, 基本配置就大功告成了.
开始部署模型
我们用到的目标检测模型来源于 NVIDIA 在 torchhub 发布的预训练模型. 我们用下面这张图来推理几个可以识别的物体(狗, 自行车以及皮卡).
可以通过下面的代码来实现推理的过程:
- public static void main(String[] args) throws IOException, ModelException, TranslateException {
- String url = "https://github.com/awslabs/djl/raw/master/examples/src/test/resources/dog_bike_car.jpg";
- BufferedImage img = BufferedImageUtils.fromUrl(url);
- Criteria<BufferedImage, DetectedObjects> criteria =
- Criteria.builder()
- .optApplication(Application.CV.OBJECT_DETECTION)
- .setTypes(BufferedImage.class, DetectedObjects.class)
- .optFilter("backbone", "resnet50")
- .optProgress(new ProgressBar())
- .build();
- try (ZooModel<BufferedImage, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
- try (Predictor<BufferedImage, DetectedObjects> predictor = model.newPredictor()) {
- DetectedObjects detection = predictor.predict(img);
- System.out.println(detection);
- }
- }
- }
然后, 就结束了. 相比于其他解决方案动辄上百行的代码, DJL 把所有过程简化到了不到 30 行完成. 那么我们看看输出的结果:
- [
- class: "dog", probability: 0.96709, bounds: [x=0.165, y=0.348, width=0.249, height=0.539]
- class: "bicycle", probability: 0.66796, bounds: [x=0.152, y=0.244, width=0.574, height=0.562]
- class: "truck", probability: 0.64912, bounds: [x=0.609, y=0.132, width=0.284, height=0.166]
- ]
你也可以用我们目标检测图形化 API 来看一下实际的检测效果:
你也许会说, 这些代码都包装的过于厉害, 真正的小白该如何上手呢?
让我们仔细的看一下刚才的那段代码:
- // 读取一张图片
- String url = "https://github.com/awslabs/djl/raw/master/examples/src/test/resources/dog_bike_car.jpg";
- BufferedImage img = BufferedImageUtils.fromUrl(url);
- // 创建一个模型的寻找标准
- Criteria<BufferedImage, DetectedObjects> criteria =
- Criteria.builder()
- // 设置应用类型: 目标检测
- .optApplication(Application.CV.OBJECT_DETECTION)
- // 确定输入输出类型 (使用默认的图片处理工具)
- .setTypes(BufferedImage.class, DetectedObjects.class)
- // 模型的过滤条件
- .optFilter("backbone", "resnet50")
- .optProgress(new ProgressBar())
- .build();
- // 创建一个模型对象
- try (ZooModel<BufferedImage, DetectedObjects> model = ModelZoo.loadModel(criteria)) {
- // 创建一个推理对象
- try (Predictor<BufferedImage, DetectedObjects> predictor = model.newPredictor()) {
- // 推理
- DetectedObjects detection = predictor.predict(img);
- System.out.println(detection);
- }
- }
这样是不是清楚了很多? DJL 建立了一个模型库 (ModelZoo) 的概念, 引入了来自于 GluonCV, TorchHub, keras 预训练模型, huggingface 自然语言处理模型等 70 多个模型. 所有的模型都可以一键导入, 用户只需要使用默认或者自己写的输入输出工具就可以实现轻松的推理. 我们还在不断的添加各种预训练模型.
了解 DJL
DJL 是亚马逊云服务在 2019 年 re:Invent 大会推出的专为 Java 开发者量身定制的深度学习框架, 现已运行在亚马逊数以百万的推理任务中.
如果要总结 DJL 的主要特色, 那么就是如下三点:
DJL 不设限制于后端引擎: 用户可以轻松的使用 MXNet, PyTorch, TensorFlow 和 fastText 来在 Java 上做模型训练和推理.
DJL 的算子设计无限趋近于 numpy: 它的使用体验上和 numpy 基本是无缝的, 切换引擎也不会造成结果改变.
DJL 优秀的内存管理以及效率机制: DJL 拥有自己的资源回收机制, 100 个小时连续推理也不会内存溢出.
James Gosling (Java 创始人) 在使用后给出了赞誉:
对于 PyTorch 的支持
来源: http://news.51cto.com/art/202007/622003.htm