StevenLimcorn commited on
Commit
44b4267
1 Parent(s): 146b45b

Initial Commit

Browse files
Files changed (3) hide show
  1. app.py +40 -0
  2. model.pth +3 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch.nn.functional as F
4
+ import torch
5
+ from torchvision import transforms
6
+
7
+ model = torch.load("/content/drive/MyDrive/Mask Detection/model.pth", map_location=torch.device("cpu"))
8
+ IMG_SIZE = 224
9
+ MASK_LABEL = ["Mask worn properly.", "Mask not worn properly: nose out", "Mask not worn properly: chin and nose out", "Didn't wear mask."]
10
+
11
+ transforms_test = transforms.Compose(
12
+ [
13
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
14
+ transforms.ToTensor(),
15
+ transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
16
+ ]
17
+ )
18
+
19
+ MASK_LABEL = ["Mask worn properly.", "Mask not worn properly: nose out", "Mask not worn properly: chin and nose out", "Didn't wear mask."]
20
+
21
+ def predict_image(image):
22
+ transformed_tensor = torch.unsqueeze(transforms_test(image), 0)
23
+ logits = model(transformed_tensor)
24
+ probability = torch.flatten(F.softmax(logits, dim=1)).detach().cpu().numpy()
25
+ print(probability)
26
+ labels = {A: B.item() for A, B in zip(MASK_LABEL, probability)}
27
+ sorted_labels = dict(sorted(labels.items(), key=lambda item: item[1], reverse=True))
28
+ print(sorted_labels)
29
+ return sorted_labels
30
+
31
+ title = "ViT Mask Detection"
32
+ description = "Gradio demo for ViT-16 Mask Image Classification created by <a href='https://github.com/stevenlimcorn'>Steven Limcorn</a>"
33
+ article = "An Application made by stevenlimcorn. Notebook access at: <a href='https://github.com/stevenlimcorn/Mask-Classification'></a>"
34
+
35
+ demo = gr.Interface(predict_image,
36
+ inputs=gr.Image(label="Input Image", type="pil", source="webcam"),
37
+ outputs=gr.Label(), title=title, description=description, article=article
38
+ )
39
+
40
+ demo.launch()
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2cee68dc4777f9133fe97ccca2414e66d20f628fdef4efcef99bfac9408259b
3
+ size 343285383
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ timm==0.4.12
2
+ torch==1.10.1
3
+ gradio
4
+ numpy
5
+ torchvision