-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgradio_app.py
52 lines (44 loc) · 2.02 KB
/
gradio_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import functools
import os
import sys
import threading

import gradio as gr
import numpy as np
import torch
import yaml
from PIL import Image
from torchvision.models import resnet50
# Make the XAI method modules in src/xai_methods importable. The script
# location is resolved to an absolute path first: before Python 3.9 the
# ``__file__`` of the main script could be relative (e.g. just
# "gradio_app.py"), in which case dirname() returns "" and the added entry
# would silently depend on the current working directory.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src', 'xai_methods'))

# Explanation helpers: Grad-CAM / HiResCAM from cam.py and Deep Feature
# Factorization from dff.py (both live in src/xai_methods).
from cam import initModel, processImage, visualize_image as visualize_gradcam
from dff import visualize_image as visualize_dff

# Load the project configuration. ``safe_load`` only constructs standard YAML
# types (dicts, lists, scalars), so the config file cannot trigger
# construction of arbitrary Python objects. The path is relative to the
# current working directory, so the app presumably has to be launched from the
# repository root.
# NOTE(review): ``cfg`` is not referenced in this file — confirm it is still needed.
with open("./CONFIG.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
# Methods served by cam.visualize_image; "dff" is handled separately.
_CAM_METHODS = ("gradcam", "hirescam")

# Serializes use of the shared, cached DFF model, in case Gradio handles
# requests concurrently: hooks attached by one request's explainer must not
# observe another request's forward pass on the same model instance.
_DFF_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _get_dff_model():
    """Return the pretrained torchvision ResNet-50 used for DFF, built once.

    Loading the pretrained weights is slow, so the model is constructed on the
    first DFF request and reused afterwards instead of being rebuilt per call.
    NOTE(review): assumes visualize_dff does not permanently mutate the model
    (e.g. leave hooks attached) — verify in dff.py.
    """
    model = resnet50(pretrained=True)
    model.eval()
    return model


def gradio_interface(image_path, method, n_components=5, top_k=2):
    """Generate an explanation overlay for an uploaded image.

    Args:
        image_path: Path of the uploaded image (the input component uses
            ``type="filepath"``); ``None`` or empty when nothing was uploaded.
        method: One of ``"gradcam"``, ``"hirescam"`` or ``"dff"``.
        n_components: Number of DFF components; ignored by the CAM methods.
            Cast to ``int`` because slider values may arrive as floats.
        top_k: DFF top-k value; ignored by the CAM methods. Cast to ``int``.

    Returns:
        Whatever the selected ``visualize_image`` helper returns, which is
        displayed by the ``gr.Image`` output component.

    Raises:
        ValueError: If ``method`` is not supported or no image was provided.
    """
    if method not in _CAM_METHODS and method != "dff":
        raise ValueError("Invalid method. Use 'gradcam', 'hirescam', or 'dff'.")
    # Fail with a clear message instead of passing None into the visualizers.
    if not image_path:
        raise ValueError("No image provided. Please upload an image.")

    if method in _CAM_METHODS:
        # Grad-CAM / HiResCAM via cam.py's visualize_image.
        return visualize_gradcam(image_path, method)

    # Deep Feature Factorization via dff.py's visualize_image.
    model = _get_dff_model()
    with _DFF_LOCK:
        return visualize_dff(model, image_path, int(n_components), int(top_k))
# Display size (pixels) shared by the input and output image components.
image_width = 400
image_height = 400

# Web UI wiring: the four inputs map positionally onto
# gradio_interface(image_path, method, n_components, top_k). The slider
# defaults (5 and 2) mirror that function's keyword defaults.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="filepath", label="Upload Image", width=image_width, height=image_height),
        # Default selection so submitting without picking a method does not
        # immediately fail with "Invalid method".
        gr.Radio(choices=["gradcam", "hirescam", "dff"], value="gradcam", label="Method"),
        gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of Components (DFF only)"),
        gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Top K (DFF only)"),
    ],
    outputs=gr.Image(type="pil", label="Output Image", width=image_width, height=image_height),
    title="Explainable AI for Image Classification",
    # The previous text asked for an "output filename", but the UI has no such
    # input; describe the controls that actually exist.
    description="Upload an image and select a method (Grad-CAM, HiResCAM, or DFF) to generate the explanation overlay. The component and top-k sliders apply to DFF only.",
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    iface.launch()