Image Segmentation
Transformers
Safetensors
cxr_basic
feature-extraction
chest_x_ray
x_ray
medical_imaging
radiology
segmentation
classification
lungs
heart
custom_code
Instructions to use Dimaodessa/chest-x-ray-basic with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Dimaodessa/chest-x-ray-basic with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-segmentation", model="Dimaodessa/chest-x-ray-basic", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("Dimaodessa/chest-x-ray-basic", trust_remote_code=True, dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
| import albumentations as A | |
| import torch | |
| import torch.nn as nn | |
| from numpy.typing import NDArray | |
| from transformers import PreTrainedModel | |
| from timm import create_model | |
| from typing import Optional | |
| from .configuration import CXRConfig | |
| from .unet import UnetDecoder, SegmentationHead | |
| _PYDICOM_AVAILABLE = False | |
| try: | |
| from pydicom import dcmread | |
| from pydicom.pixels import apply_voi_lut | |
| _PYDICOM_AVAILABLE = True | |
| except ModuleNotFoundError: | |
| pass | |
class CXRModel(PreTrainedModel):
    """Chest X-ray model with a segmentation head and a small classifier.

    A `timm` feature-extracting encoder feeds a U-Net style decoder and
    segmentation head.  The deepest encoder feature map is additionally
    average-pooled and passed through a linear layer whose outputs are
    exposed as ``"age"``, ``"view"`` and ``"female"``.
    """

    config_class = CXRConfig

    def __init__(self, config):
        """Build encoder, decoder and heads from the fields of *config*."""
        super().__init__(config)
        # Weights are expected to come from the checkpoint, hence
        # ``pretrained=False`` for the timm backbone.
        self.encoder = create_model(
            model_name=config.backbone,
            features_only=True,
            pretrained=False,
            in_chans=config.in_chans,
        )
        self.decoder = UnetDecoder(
            decoder_n_blocks=config.decoder_n_blocks,
            decoder_channels=config.decoder_channels,
            encoder_channels=config.encoder_channels,
            decoder_center_block=config.decoder_center_block,
            decoder_norm_layer=config.decoder_norm_layer,
            decoder_attention_type=config.decoder_attention_type,
        )
        self.img_size = config.img_size
        self.segmentation_head = SegmentationHead(
            in_channels=config.decoder_channels[-1],
            out_channels=config.seg_num_classes,
            size=self.img_size,
        )
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=config.cls_dropout)
        self.classifier = nn.Linear(config.feature_dim, config.cls_num_classes)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Linearly map values from [0, 255] to [-1, 1]."""
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        x = (x - 0.5) * 2.0
        return x

    @staticmethod
    def load_image_from_dicom(path: str) -> Optional[NDArray]:
        """Read a DICOM file and return its pixels rescaled to uint8.

        Declared as a static method so it can be called both on the class
        and on a model instance; it does not use any model state.

        Args:
            path: Location of the DICOM file, passed to ``pydicom.dcmread``.

        Returns:
            The image as a ``uint8`` array scaled to [0, 255], or ``None``
            when `pydicom` is not installed.
        """
        if not _PYDICOM_AVAILABLE:
            print("`pydicom` is not installed, returning None ...")
            return None
        dicom = dcmread(path)
        arr = apply_voi_lut(dicom.pixel_array, dicom)
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            # invert image if needed
            arr = arr.max() - arr
        arr = arr - arr.min()
        peak = arr.max()
        # A constant image has peak == 0 after the shift; dividing would
        # yield NaN, so leave it as all zeros instead.
        if peak > 0:
            arr = arr / peak
        arr = (arr * 255).astype("uint8")
        return arr

    def preprocess(self, x: NDArray) -> NDArray:
        """Resize *x* to ``self.img_size`` using albumentations.

        ``img_size[0]`` and ``img_size[1]`` are forwarded positionally to
        ``A.Resize`` (presumably height then width; confirm against the
        config).
        """
        resize = A.Resize(self.img_size[0], self.img_size[1], p=1)
        return resize(image=x)["image"]

    def forward(
        self, x: torch.Tensor, return_logits: bool = False
    ) -> "dict[str, torch.Tensor]":
        """Run segmentation and classification on a batch.

        Args:
            x: Input batch with values in [0, 255]; it is normalized to
                [-1, 1] internally before entering the encoder.
            return_logits: When true, return raw head outputs without any
                activation applied.

        Returns:
            A dict with keys ``"mask"``, ``"age"``, ``"view"`` and
            ``"female"``.  Unless *return_logits* is set, ``"mask"`` and
            ``"view"`` are softmaxed over dim 1 and ``"female"`` is passed
            through a sigmoid; ``"age"`` is always the raw linear output.
        """
        x = self.normalize(x)
        features = self.encoder(x)
        decoder_output = self.decoder(features)
        logits = self.segmentation_head(decoder_output[-1])

        # Classification branch: global-average-pool the deepest encoder map.
        deepest = features[-1]
        b, n = deepest.shape[:2]
        pooled = self.pooling(deepest).reshape(b, n)
        pooled = self.dropout(pooled)
        cls_logits = self.classifier(pooled)

        # Classifier columns: 0 -> age, 1..3 -> view, 4 -> female.
        out = {
            "mask": logits,
            "age": cls_logits[:, 0].unsqueeze(1),
            "view": cls_logits[:, 1:4],
            "female": cls_logits[:, 4].unsqueeze(1),
        }
        if return_logits:
            return out
        out["mask"] = out["mask"].softmax(1)
        out["view"] = out["view"].softmax(1)
        out["female"] = out["female"].sigmoid()
        return out