import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import gradio as gr
import timm

class ImprovedMultiOutputModel(nn.Module):
    """Improved multi-output model with an EfficientNet backbone."""

    def __init__(self, num_object_classes, num_material_classes, backbone='efficientnet_b0'):
        super().__init__()
        # Use a pretrained EfficientNet backbone as a feature extractor;
        # num_classes=0 makes timm drop its classification head.
        self.backbone = timm.create_model(backbone, pretrained=True, num_classes=0)
        backbone_out_features = self.backbone.num_features

        # Attention mechanism: a sigmoid gate that reweights backbone features.
        self.attention = nn.Sequential(
            nn.Linear(backbone_out_features, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, backbone_out_features),
            nn.Sigmoid()
        )

        # Classification heads with batch norm and dropout for regularization.
        self.object_classifier = nn.Sequential(
            nn.Linear(backbone_out_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_object_classes)
        )
        self.material_classifier = nn.Sequential(
            nn.Linear(backbone_out_features, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_material_classes)
        )

    def forward(self, x):
        # Extract features using the backbone.
        features = self.backbone(x)
        # Apply the attention gate.
        attention_weights = self.attention(features)
        features = features * attention_weights
        # Get logits for each attribute.
        object_pred = self.object_classifier(features)
        material_pred = self.material_classifier(features)
        return {
            'object_name': object_pred,
            'material': material_pred,
        }
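
# A minimal shape check for the architecture (hypothetical class counts; left
# commented out so it does not run on Space startup):
#
#   model = ImprovedMultiOutputModel(num_object_classes=10, num_material_classes=5)
#   model.eval()
#   with torch.no_grad():
#       out = model(torch.randn(2, 3, 224, 224))
#   # out['object_name'].shape == (2, 10); out['material'].shape == (2, 5)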

def get_val_transforms():
    """Get transforms for validation/inference (standard ImageNet mean/std)."""
    return transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
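
# Applied to a PIL image, this pipeline yields a float tensor of shape
# (3, 224, 224), e.g. (hypothetical file path):
#
#   tensor = get_val_transforms()(Image.open('example.jpg').convert('RGB'))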

def load_model(model_path):
    """Rebuild the model from a checkpoint and return it with its label mappings."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(model_path, map_location=device)
    label_mappings = checkpoint['label_mappings']
    num_object_classes = len(label_mappings['object_name'])
    num_material_classes = len(label_mappings['material'])
    backbone = 'efficientnet_b0'
    model = ImprovedMultiOutputModel(num_object_classes, num_material_classes, backbone)
    # strict=False tolerates minor key mismatches between checkpoint and model.
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.to(device)
    model.eval()
    return model, label_mappings
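
# The checkpoints are assumed to follow the structure read above, roughly:
#
#   torch.save({
#       'model_state_dict': model.state_dict(),
#       'label_mappings': {
#           'object_name': {...},  # name -> class id
#           'material': {...},     # name -> class id
#       },
#   }, 'modelv1.pth')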
# Load models
models = {}
models['modelv1.pth'], label_mappings_v1 = load_model('modelv1.pth')
models['modelv2.pth'], label_mappings_v2 = load_model('modelv2.pth')
# Label mappings are assumed identical for both checkpoints; use v1's.
label_mappings = label_mappings_v1

def predict(image, model_choice):
    if image is None:
        return "Please upload an image."
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models[model_choice]
    transform = get_val_transforms()
    # Ensure 3 channels (uploads may be grayscale or RGBA) and add a batch dim.
    image_tensor = transform(image.convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(image_tensor)
        pred_obj = torch.argmax(outputs['object_name'], dim=1).item()
        pred_mat = torch.argmax(outputs['material'], dim=1).item()
    # Map predicted IDs back to names.
    obj_name = [k for k, v in label_mappings['object_name'].items() if v == pred_obj][0]
    mat_name = [k for k, v in label_mappings['material'].items() if v == pred_mat][0]
    return f"Predicted Object: {obj_name}\nPredicted Material: {mat_name}"

# Create the Gradio interface using Blocks.
with gr.Blocks(title="Artifact Classification Model") as demo:
    gr.Markdown("# Artifact Classification Model")
    gr.Markdown("Upload an image to classify the object name and material.")
    model_selector = gr.Dropdown(choices=['modelv1.pth', 'modelv2.pth'], label="Select Model", value='modelv1.pth')
    with gr.Row():
        input_image = gr.Image(type="pil", label="Upload an Image")
        output_text = gr.Textbox(label="Predictions")
    predict_btn = gr.Button("Predict")
    predict_btn.click(fn=predict, inputs=[input_image, model_selector], outputs=output_text)

if __name__ == "__main__":
    demo.launch()