Spaces:
Runtime error
Runtime error
| from torchvision import transforms | |
| import torch | |
| import urllib | |
| from PIL import Image | |
| import gradio as gr | |
| import torch | |
# Example image offered in the demo's "examples" gallery.
torch.hub.download_url_to_file('https://images.pexels.com/photos/17811/pexels-photo.jpg', 'bird.jpg')

# NTS-Net from the nicolalandro/ntsnet-cub200 hub repo (200 classes, CPU).
# Loaded exactly once: the script previously called torch.hub.load twice,
# paying the download/deserialisation cost and memory for a second copy.
model = torch.hub.load(
    'nicolalandro/ntsnet-cub200',
    'ntsnet',
    pretrained=True,
    topN=6,
    device='cpu',
    num_classes=200,
)
# Inference only: make layers such as dropout / batch-norm use eval behaviour.
# NOTE(review): harmless if the hub entry point already returns an eval-mode model.
model.eval()

# Test-time preprocessing: resize to 600x600, center-crop to 448x448, convert
# to a tensor and normalize per channel (the usual ImageNet mean/std values).
transform_test = transforms.Compose([
    transforms.Resize((600, 600), interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def birds(img):
    """Predict the bird class shown in ``img``.

    Args:
        img: PIL image from the Gradio ``Image(type='pil')`` input. It may be
            ``None`` if the user submits without an image, and it may be in
            any PIL mode (e.g. RGBA PNGs or grayscale 'L').

    Returns:
        The entry of ``model.bird_classes`` for the top-scoring class, with
        everything up to and including the first '.' removed. Names are
        presumably of the form ``"<index>.<name>"``; a name without a '.' is
        returned unchanged.

    Raises:
        ValueError: if ``img`` is None.
    """
    if img is None:
        raise ValueError("Please upload an image.")
    # transform_test normalizes with a 3-channel mean/std, so the input must be
    # RGB; RGBA or grayscale uploads would otherwise fail with a shape error.
    scaled_img = transform_test(img.convert('RGB'))
    batch = scaled_img.unsqueeze(0)  # add a leading batch dimension of size 1
    with torch.no_grad():
        # The model returns 7 values; only the concatenated logits (4th) drive
        # the final prediction.
        _, _, _, concat_logits, _, _, _ = model(batch)
    pred_id = concat_logits.argmax(dim=1).item()
    return model.bird_classes[pred_id].split('.', 1)[-1]
# Gradio UI: a PIL image in, the predicted bird class name out.
# gr.inputs / gr.outputs were deprecated in Gradio 3 and removed in Gradio 4;
# the top-level components used here exist in both versions.
inputs = gr.Image(type='pil', label="Original Image")
outputs = gr.Textbox(label="bird class")
title = "ntsnet"
description = "demo for ntsnet to classify birds. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
article = "<p style='text-align: center'><a href='http://artelab.dista.uninsubria.it/res/research/papers/2019/2019-IVCNZ-Nawaz-Birds.pdf'>Are These Birds Similar: Learning Branched Networks for Fine-grained Representations</a> | <a href='https://github.com/nicolalandro/ntsnet-cub200'>Github Repo</a></p>"
# Each example row supplies one value per input component (here: an image path).
examples = [
    ['bird.jpg'],
]

demo = gr.Interface(
    birds,
    inputs,
    outputs,
    title=title,
    description=description,
    article=article,
    examples=examples,
    analytics_enabled=False,
)
demo.launch()