Gradio 案例升级
1. 文本分类—垃圾邮件分类
import gradio as gr
from sklearn.feature_extraction.text import CountVectorizer
import re
import zhconv
import jieba.posseg as psg
import pickle
import joblib
def clean_data(email):
    """Normalize a raw e-mail body into a space-separated string of kept words.

    Args:
        email: Raw e-mail text.

    Returns:
        A single string in which the retained words are joined by one space.
    """
    # 1. Drop every character outside the \u4e00-\u9fa5 range.
    chinese_only = re.sub(r'[^\u4e00-\u9fa5]', '', email)
    # 2. Convert Traditional Chinese to Simplified Chinese.
    simplified = zhconv.convert(chinese_only, 'zh-cn')
    # 3. Segment the text and keep only words whose part-of-speech tag is
    #    in the allow-list.
    allow_pos = ['n', 'nr', 'ns', 'nt', 'v', 'a']
    kept_words = [word for word, pos in psg.cut(simplified) if pos in allow_pos]
    # 4. Collapse the word list back into a str.
    return ' '.join(kept_words)
# Holds the (vectorizer, model) pair once it has been read from disk, so the
# files are loaded a single time instead of on every prediction request.
_EMAIL_ARTIFACTS = {}


def _load_email_artifacts():
    """Return the (vectorizer, model) pair, loading both from disk on first use."""
    if 'pair' not in _EMAIL_ARTIFACTS:
        # NOTE(review): pickle/joblib can execute arbitrary code while loading;
        # these two files must come from a trusted source.
        with open('03-模型训练特征.pkl', 'rb') as vocab_file:
            vocab = pickle.load(vocab_file)
        transfer = CountVectorizer(vocabulary=vocab)
        model = joblib.load('04-邮件分类模型.pth')
        # Publish both objects in one assignment so a concurrent caller never
        # observes a half-populated cache.
        _EMAIL_ARTIFACTS['pair'] = (transfer, model)
    return _EMAIL_ARTIFACTS['pair']


def email_handle(text):
    """Classify one e-mail text as spam or normal mail.

    Args:
        text: Raw e-mail text entered by the user.

    Returns:
        '垃圾邮件' when the model predicts the label 'spam', otherwise
        '正常邮件'.
    """
    # 1. Clean the raw text.
    content = clean_data(text)
    # 2. Vectorize with the vocabulary saved at training time.
    transfer, model = _load_email_artifacts()
    content = transfer.transform([content])
    # 3. Predict; only a single sample was passed in, so take element 0.
    output = model.predict(content)
    prediction = output[0]
    prediction = '垃圾邮件' if prediction == 'spam' else '正常邮件'
    return prediction
# Text box in, label out; email_handle runs on every submission.
demo = gr.Interface(
    fn=email_handle,
    inputs="text",
    outputs="label",
)
# Shut down any interface that is still running before launching this one.
gr.close_all()
demo.launch()
2. 图像分类
from torchvision import transforms
import gradio as gr
import torch
import torchvision.models as models
# Fetch the pretrained ResNet-18 through torch.hub and switch it to eval mode.
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
model = model.eval()

# Path of the label text file; replace it with the path of your actual file.
file_path = 'labels.txt'
with open(file_path, 'r') as file:
    # One label per line; strip the trailing newline/whitespace of each line.
    labels = [line.rstrip() for line in file]
# Prediction function
def predict(inp):
    """Run the classifier on one image and return a label -> confidence mapping.

    Args:
        inp: The input image (the interface passes a PIL image).

    Returns:
        A dict that maps each label to its softmax score as a Python float.
    """
    # Convert the image to a tensor and add a leading batch dimension.
    inp = transforms.ToTensor()(inp).unsqueeze(0)
    with torch.no_grad():
        # model(inp)[0] is the output for the single image in the batch.
        prediction = torch.nn.functional.softmax(model(inp)[0], dim=0)
    # Pair labels with scores instead of indexing a hard-coded range(1000),
    # so a label file with fewer entries no longer raises IndexError.
    confidences = {
        label: float(score) for label, score in zip(labels, prediction)
    }
    return confidences
# fn: the callback; clicking the submit button triggers predict for inference.
# inputs: the uploaded image is handed to the callback as a PIL.Image.
# outputs: the result is displayed as a label component.
# examples: sample images shown in the UI for users to try.
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    examples=["demo/cat.jpg", "demo/tiger.jpeg"],
).launch()
3. 图片筛选器
尽管gradio的设计初衷是为了快速创建机器学习用户交互页面。但实际上,通过组合gradio的各种组件,用户可以很方便地实现非常实用的各种应用小工具。
例如: 数据分析展示dashboard, 数据标注工具, 制作一个小游戏界面等等。
本范例我们将应用 gradio 来构建一个图片筛选器,从百度爬取的一堆猫咪表情包中筛选出一些我们喜欢的图片。
#!pip install -U torchkeras
import torchkeras
from torchkeras.data import download_baidu_pictures
# download_baidu_pictures('猫咪表情包', 100)
import gradio as gr
from PIL import Image
import time, os
from pathlib import Path
# Folder holding the downloaded pictures, and folder receiving the picked ones.
base_dir = '猫咪表情包'
selected_dir = 'selected'
# Every *.jp*g path found recursively under base_dir, as strings, skipping
# any path that contains 'checkpoint'.
files = [
    p for p in map(str, Path(base_dir).rglob('*.jp*g'))
    if 'checkpoint' not in p
]
def show_img(path):
    """Open the image file at *path* with PIL and return the image object."""
    image = Image.open(path)
    return image
def _image_sort_key(name):
    """Order file names numerically by stem; non-numeric names sort last, by name."""
    stem = os.path.splitext(name)[0]
    try:
        return (0, int(stem), name)
    except ValueError:
        return (1, 0, name)


def fn_next(done, todo):
    """Advance to the next image in ``base_dir`` and update the counters.

    Args:
        done: Number of images already handled (value of the 'done' widget).
        todo: Number of images still to handle (value of the 'todo' widget).

    Returns:
        A tuple ``(done, todo, path, img, msg)``: the updated counters, the
        path of the image now shown, the opened image, and a status message.

    Raises:
        FileNotFoundError: If ``base_dir`` contains no regular files.
    """
    # Number widgets may deliver floats (or None when cleared); list indices
    # must be ints.
    done = int(done or 0)
    todo = int(todo or 0)

    # Sort the directory entries by their numeric file stem. Sub-directories
    # are skipped because they cannot be opened as images.
    # NOTE(review): this listing is independent of the module-level `files`
    # list, so its order and length may differ from it - confirm intended.
    image_names = sorted(
        (name for name in os.listdir(base_dir)
         if os.path.isfile(os.path.join(base_dir, name))),
        key=_image_sort_key,
    )
    if not image_names:
        raise FileNotFoundError(f"no image files found in {base_dir!r}")

    next_index = done + 1
    if next_index < len(image_names):
        done = next_index
        todo = max(todo - 1, 0)
        # A newly shown image has not been picked yet, so reset the status
        # text (the original returned an undefined local `msg` here).
        msg = get_default_msg()
    else:
        # Already at the end: keep the counters and stay on the last image
        # instead of raising IndexError.
        next_index = len(image_names) - 1
        msg = "已经是最后一张图片!"

    path = os.path.join(base_dir, image_names[next_index])
    img = show_img(path)
    return done, todo, path, img, msg
def save_selected(img_path):
    """Save an RGB copy of the image at *img_path* into ``selected_dir``.

    Args:
        img_path: Path of the image currently shown.

    Returns:
        A confirmation message string.
    """
    # Create the destination folder on first use; saving into a missing
    # folder would fail.
    os.makedirs(selected_dir, exist_ok=True)
    # Keep only the file name. img_path already carries its source folder, so
    # joining it whole would point into a non-existent sub-folder of
    # selected_dir, or, for an absolute path, back at the original file.
    target = os.path.join(selected_dir, os.path.basename(img_path))
    # The context manager closes the source file once the copy is written.
    with Image.open(img_path) as img:
        img.convert("RGB").save(target)
    msg = "Image saved successfully!"
    return msg
def get_default_msg():
    """Return the status text shown while the current image is not yet picked."""
    return "图像未筛选状态!"
# NOTE(review): the original indentation was lost; the nesting below is
# reconstructed from statement order - confirm the intended row layout.
with gr.Blocks() as demo:
    # Row 1: total number of images found at start-up.
    with gr.Row():
        total = gr.Number(value=len(files), label='总数量')
    # Row 2: progress counters.
    with gr.Row():
        done = gr.Number(value=0, label='已完成')
        todo = gr.Number(value=len(files), label='待完成')
    bn_next = gr.Button(value="下一张")
    # Path of the image currently displayed; starts at the first file.
    path = gr.Text(value=files[0], lines=1, label='当前图片路径')
    feedback_button = gr.Button(value="选择图片", variant="primary")
    msg = gr.TextArea(value=get_default_msg(), lines=3, max_lines=5)
    img = gr.Image(value=show_img(files[0]), type='pil')

    # "Next" advances the counters and swaps the displayed image.
    bn_next.click(
        fn=fn_next,
        inputs=[done, todo],
        outputs=[done, todo, path, img, msg],
    )
    # "Select" saves the image at the current path and reports the result.
    feedback_button.click(
        fn=save_selected,
        inputs=path,
        outputs=msg,
    )

demo.launch()