使用 torch.hub 模块载入预训练的 DeepLabV3-ResNet101 语义分割模型，输入一张图片得到模型的输出结果，并对输出结果进行可视化处理。
# GitHub: https://github.com/pytorch/hub
# Semantic-segmentation demo: load a pretrained DeepLabV3-ResNet101 through
# torch.hub, run it on a sample image and display the per-pixel class map.
import urllib.request

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

HUB_REPO = "pytorch/vision:v0.4.2"

# Load the pretrained model and put it in evaluation (inference) mode.
model = torch.hub.load(HUB_REPO, "deeplabv3_resnet101", pretrained=True)
model.eval()
# Print the entry points (model names) published by this hub repository.
print(torch.hub.list(HUB_REPO))

# Download the sample image. `urllib.URLopener` is a Python 2 API; on Python 3
# the function lives in `urllib.request`, which must be imported explicitly
# (a bare `import urllib` does not guarantee the submodule is loaded).
url, filename = ("https://github.com/pytorch/hub/raw/master/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)

# Force three channels: the Normalize step below has exactly three mean/std
# values, so grayscale or RGBA images would otherwise fail.
input_image = Image.open(filename).convert("RGB")

# Image -> float tensor (C, H, W) in [0, 1], then normalize per channel with
# the ImageNet statistics the pretrained weights were trained with.
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # add a batch dimension: (1, C, H, W)

# Run on the GPU when one is available; input and model must share a device.
device = "cuda" if torch.cuda.is_available() else "cpu"
input_batch = input_batch.to(device)
model.to(device)

with torch.no_grad():  # inference only: skip autograd bookkeeping
    # "out" holds the per-class scores; [0] selects the single sample,
    # leaving (num_classes, H, W).
    output = model(input_batch)["out"][0]
output_predictions = output.argmax(0)  # per-pixel index of the best class

# One pseudo-random RGB colour per class: class index times a fixed
# per-channel constant, reduced modulo 255.
num_classes = output.shape[0]
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
colors = torch.arange(num_classes)[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")

# Class-index map -> 8-bit image (.byte() assumes at most 256 classes).
# NEAREST resampling keeps every pixel a valid class index; interpolating
# resamplers would blend neighbouring labels into meaningless values.
r = Image.fromarray(output_predictions.byte().cpu().numpy())
r = r.resize(input_image.size, resample=Image.NEAREST)
r.putpalette(colors.tobytes())  # attach the colour table ("L" -> "P" mode)

# plt.show() takes no image argument; draw it with imshow first.
plt.imshow(r)
plt.show()
来源: http://www.bubuko.com/infodetail-3495811.html