Spaces:
Sleeping
Sleeping
| """Visualization helpers for single-case PanCancerSeg inference.""" | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| import SimpleITK as sitk | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
# Default segmentation overlay colour as an (R, G, B) tuple (pure red).
# Overlays are composed in RGB; frames are converted to BGR only at the point
# they are handed to OpenCV's video writer.
DEFAULT_OVERLAY_COLOR = (255, 0, 0)
def preprocess_volume(volume, wl, ww):
    """Apply CT windowing and return uint8 data in [0, 255].

    Intensities are clipped to the window ``[wl - ww / 2, wl + ww / 2]`` and
    the window bounds themselves are mapped linearly onto ``[0, 255]``.
    Mapping by the window (rather than by the data's own min/max) keeps the
    grey scale fixed for a given window, so volumes or crops that do not
    span the whole window are not contrast-stretched.

    Args:
        volume: Array-like of intensities (any shape).
        wl: Window level (centre of the window).
        ww: Window width; must be finite and strictly positive.

    Returns:
        ``np.uint8`` array with the same shape as ``volume``. NaN voxels are
        rendered as the bottom of the window (0).

    Raises:
        ValueError: If ``wl`` or ``ww`` is not finite, or ``ww <= 0``.
    """
    if not (np.isfinite(wl) and np.isfinite(ww)) or ww <= 0:
        raise ValueError(
            f"Window level/width must be finite with a positive width, got wl={wl}, ww={ww}"
        )
    lower_bound = wl - ww / 2
    upper_bound = wl + ww / 2

    # np.array copies, so the in-place steps below never modify the caller's data.
    windowed = np.array(volume, dtype=np.float32)
    np.clip(windowed, lower_bound, upper_bound, out=windowed)
    # np.clip propagates NaN; show missing voxels as black instead of letting
    # them poison the whole conversion.
    windowed[np.isnan(windowed)] = lower_bound

    windowed -= lower_bound
    windowed /= ww
    windowed *= 255.0
    return np.clip(windowed, 0, 255).astype(np.uint8)
def overlay_mask(gray_slice, mask_slice, color=DEFAULT_OVERLAY_COLOR, alpha=0.5):
    """Apply a semi-transparent RGB overlay to one grayscale slice.

    Args:
        gray_slice: 2D array-like; converted to ``np.uint8``.
        mask_slice: Array-like with the same shape as ``gray_slice``; values
            ``> 0`` are treated as foreground.
        color: Overlay colour as an (R, G, B) triple.
        alpha: Blend weight of ``color``; foreground pixels become
            ``gray * (1 - alpha) + color * alpha``, clipped to [0, 255].

    Returns:
        ``(H, W, 3)`` uint8 RGB image. Background pixels are the grey value
        replicated across the three channels.

    Raises:
        ValueError: If ``gray_slice`` is not 2D or the mask shape differs.
    """
    gray_slice = np.asarray(gray_slice, dtype=np.uint8)
    if gray_slice.ndim != 2:
        raise ValueError(f"Expected a 2D grayscale slice, got shape {gray_slice.shape}")
    mask = np.asarray(mask_slice) > 0
    if mask.shape != gray_slice.shape:
        raise ValueError(
            f"Mask shape {mask.shape} does not match slice shape {gray_slice.shape}"
        )

    # np.stack allocates a fresh array, so it can be edited in place directly.
    rgb = np.stack([gray_slice] * 3, axis=-1)
    if not mask.any():
        return rgb

    color_arr = np.asarray(color, dtype=np.float32)
    blended = rgb[mask].astype(np.float32) * (1 - alpha) + color_arr * alpha
    rgb[mask] = np.clip(blended, 0, 255).astype(np.uint8)
    return rgb
def find_key_slices(mask_vol):
    """Return named representative z-slices for a mask in [z, y, x] order.

    Voxels with a value ``> 0`` count as foreground. When the mask has
    foreground, the returned indices are:

    * ``"centroid"``: z of the foreground centroid, rounded to a slice index.
    * ``"max_area"``: slice with the most foreground voxels (first on ties).
    * ``"extent25"`` / ``"extent75"``: slices 25 % / 75 % of the way from the
      first to the last slice containing foreground.

    For an empty mask the same four keys are returned, but they hold fixed
    slices spread around the middle of the volume (the key names do not
    describe those picks).

    Raises:
        ValueError: If ``mask_vol`` is not 3D or has no slices.
    """
    mask_vol = np.asarray(mask_vol)
    if mask_vol.ndim != 3:
        raise ValueError(f"Expected a 3D mask volume, got shape {mask_vol.shape}")
    depth = mask_vol.shape[0]
    if depth == 0:
        raise ValueError("Cannot select key slices from an empty z-dimension")

    # Foreground voxel count per slice. Working from these counts avoids
    # materialising an (N, 3) coordinate array just to average z.
    areas = np.count_nonzero(mask_vol > 0, axis=(1, 2)).astype(np.int64)
    total = int(areas.sum())
    if total > 0:
        occupied = np.flatnonzero(areas)
        min_z = int(occupied[0])
        max_z = int(occupied[-1])
        # Mean z over all foreground voxels == area-weighted mean slice index.
        weighted_z = int(np.dot(np.arange(depth, dtype=np.int64), areas))
        centroid_z = int(round(weighted_z / total))
        # All four values lie within [min_z, max_z], so no clipping is needed.
        return {
            "centroid": centroid_z,
            "max_area": int(areas.argmax()),
            "extent25": int(round(min_z + 0.25 * (max_z - min_z))),
            "extent75": int(round(min_z + 0.75 * (max_z - min_z))),
        }

    middle = depth // 2
    offset = max(1, depth // 10)
    return {
        "centroid": middle,
        "max_area": _clip_slice(middle - offset, depth),
        "extent25": _clip_slice(middle + offset, depth),
        "extent75": _clip_slice(middle + 2 * offset, depth),
    }
def generate_slice_images(
    image_uint8,
    mask_vol,
    output_dir,
    case_name,
    color=DEFAULT_OVERLAY_COLOR,
    alpha=0.5,
):
    """Save side-by-side PNGs (image | overlay) for representative slices.

    The slices are chosen by ``find_key_slices(mask_vol)``. One PNG per key
    is written to ``output_dir / f"{case_name}_slice_{label}.png"``.

    Args:
        image_uint8: Windowed uint8 volume indexed ``[z, y, x]``.
        mask_vol: Segmentation volume with the same ``[z, y, x]`` layout.
        output_dir: Directory to write into; created if missing.
        case_name: Prefix used in titles and file names.
        color: Overlay colour as an (R, G, B) triple.
        alpha: Overlay blend weight.

    Returns:
        Dict mapping each key-slice label to the ``Path`` of its PNG.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    key_slices = find_key_slices(mask_vol)
    outputs = {}
    for label, z_idx in key_slices.items():
        gray_slice = image_uint8[z_idx]
        overlay = overlay_mask(gray_slice, mask_vol[z_idx] > 0, color=color, alpha=alpha)
        out_path = output_dir / f"{case_name}_slice_{label}.png"
        fig, axes = plt.subplots(1, 2, figsize=(10, 5), dpi=150)
        try:
            axes[0].imshow(gray_slice, cmap="gray", vmin=0, vmax=255)
            axes[0].set_title("Image")
            axes[0].axis("off")
            axes[1].imshow(overlay)
            axes[1].set_title("Segmentation overlay")
            axes[1].axis("off")
            fig.suptitle(f"{case_name} | z={z_idx}")
            fig.tight_layout()
            fig.savefig(out_path, dpi=150, bbox_inches="tight", facecolor="white")
        finally:
            # Always release the figure, even if drawing or saving fails,
            # so repeated calls don't accumulate open pyplot figures.
            plt.close(fig)
        outputs[label] = out_path
    return outputs
def generate_video(
    image_uint8,
    mask_vol,
    output_dir,
    case_name,
    cancer_type,
    color=DEFAULT_OVERLAY_COLOR,
    alpha=0.5,
    fps=10,
):
    """Generate an MP4 scroll-through overlay video.

    Frames cover the inclusive z-range returned by ``_video_z_range(mask_vol)``.
    Each frame is built by ``_make_video_frame`` (RGB) and converted to BGR
    only when written to OpenCV. The writer's size is taken from the first
    rendered frame.

    Args:
        image_uint8: Windowed uint8 volume indexed ``[z, y, x]``.
        mask_vol: Segmentation volume with the same layout.
        output_dir: Directory to write into; created if missing.
        case_name: Prefix of the output file name.
        cancer_type: Text shown in each frame's annotation.
        color: Overlay colour as an (R, G, B) triple.
        alpha: Overlay blend weight.
        fps: Frames per second of the output video.

    Returns:
        ``Path`` of the written ``{case_name}_overlay.mp4``.

    Raises:
        ValueError: If the volume has no slices.
        RuntimeError: From ``_open_video_writer`` if no MP4 codec is usable.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    video_path = output_dir / f"{case_name}_overlay.mp4"

    depth = image_uint8.shape[0]
    if depth == 0:
        raise ValueError("Cannot render a video from an empty z-dimension")
    start_z, end_z = _video_z_range(mask_vol)

    writer = None
    try:
        for z_idx in range(start_z, end_z + 1):
            frame = _make_video_frame(
                image_uint8[z_idx],
                mask_vol[z_idx],
                color,
                alpha,
                z_idx,
                depth,
                cancer_type,
            )
            if writer is None:
                # Open lazily so the writer matches the (possibly upscaled) frame size.
                height, width = frame.shape[:2]
                writer = _open_video_writer(video_path, fps, width, height)
            # Frame annotations are drawn in RGB space; convert only when writing to OpenCV.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Release the writer (flushing/closing the file) even if a frame fails.
        if writer is not None:
            writer.release()
    return video_path
def generate_outputs(
    image_path,
    mask_path,
    output_dir,
    case_name,
    cancer_type,
    wl,
    ww,
    color=DEFAULT_OVERLAY_COLOR,
    alpha=0.5,
    fps=10,
):
    """Load an image/segmentation pair and render all preview artefacts.

    Both files are read with SimpleITK and converted to numpy arrays, which
    must have identical shapes. The image is windowed with ``wl``/``ww`` and
    then key-slice PNGs and the overlay MP4 are written to ``output_dir``.

    Returns:
        ``{"slices": {label: Path, ...}, "video": Path}``.

    Raises:
        ValueError: If the image and segmentation array shapes differ.
    """

    def load_array(path):
        # SimpleITK hands back numpy arrays in [z, y, x] order.
        return sitk.GetArrayFromImage(sitk.ReadImage(str(path)))

    image_vol = load_array(image_path)
    mask_vol = load_array(mask_path)

    if mask_vol.shape != image_vol.shape:
        raise ValueError(
            "Image and segmentation shapes do not match: "
            f"image={image_vol.shape}, segmentation={mask_vol.shape}. "
            "Both arrays are expected in [z, y, x] order."
        )

    windowed = preprocess_volume(image_vol, wl, ww)
    return {
        "slices": generate_slice_images(
            windowed, mask_vol, output_dir, case_name, color=color, alpha=alpha
        ),
        "video": generate_video(
            windowed,
            mask_vol,
            output_dir,
            case_name,
            cancer_type,
            color=color,
            alpha=alpha,
            fps=fps,
        ),
    }
| def _normalize_to_uint8(volume): | |
| v_min = float(np.min(volume)) | |
| v_max = float(np.max(volume)) | |
| if not np.isfinite(v_min) or not np.isfinite(v_max) or v_max <= v_min: | |
| return np.zeros(volume.shape, dtype=np.uint8) | |
| normalized = (volume - v_min) / (v_max - v_min) * 255.0 | |
| return np.clip(normalized, 0, 255).astype(np.uint8) | |
| def _clip_slice(index, depth): | |
| return int(np.clip(index, 0, depth - 1)) | |
| def _video_z_range(mask_vol, padding=10, empty_window=80): | |
| depth = mask_vol.shape[0] | |
| mask = mask_vol > 0 | |
| if np.any(mask): | |
| z_indices = np.where(np.any(mask, axis=(1, 2)))[0] | |
| return ( | |
| max(0, int(z_indices.min()) - padding), | |
| min(depth - 1, int(z_indices.max()) + padding), | |
| ) | |
| if depth <= empty_window: | |
| return 0, depth - 1 | |
| middle = depth // 2 | |
| half = empty_window // 2 | |
| return max(0, middle - half), min(depth - 1, middle + half) | |
def _make_video_frame(gray_slice, mask_slice, color, alpha, z_idx, depth, cancer_type):
    """Render one annotated RGB video frame for slice ``z_idx``.

    The slice is overlaid via ``overlay_mask``, upscaled via
    ``_upscale_if_small``, and stamped top-left with
    ``"Slice {z_idx + 1}/{depth} | {cancer_type}"`` in white text on a
    filled black box.
    """
    canvas = _upscale_if_small(
        overlay_mask(gray_slice, mask_slice, color=color, alpha=alpha)
    )
    label = f"Slice {z_idx + 1}/{depth} | {cancer_type}"

    face = cv2.FONT_HERSHEY_SIMPLEX
    # Text grows with the frame's short side but never drops below 0.6.
    scale = max(0.6, min(canvas.shape[:2]) / 900)
    stroke = max(1, int(round(scale * 2)))
    (text_w, text_h), baseline = cv2.getTextSize(label, face, scale, stroke)

    margin = 6
    left = 12
    text_bottom = 12 + text_h
    top_left = (left - margin, text_bottom - text_h - margin)
    bottom_right = (left + text_w + margin, text_bottom + baseline + margin)
    cv2.rectangle(canvas, top_left, bottom_right, (0, 0, 0), thickness=-1)
    cv2.putText(
        canvas,
        label,
        (left, text_bottom),
        face,
        scale,
        (255, 255, 255),
        stroke,
        cv2.LINE_AA,
    )
    return canvas
| def _upscale_if_small(frame, min_short_side=512): | |
| height, width = frame.shape[:2] | |
| short_side = min(height, width) | |
| if short_side >= min_short_side: | |
| return frame | |
| scale = min_short_side / short_side | |
| new_size = (int(round(width * scale)), int(round(height * scale))) | |
| return cv2.resize(frame, new_size, interpolation=cv2.INTER_LINEAR) | |
def _open_video_writer(video_path, fps, width, height):
    """Open an OpenCV ``VideoWriter`` for ``video_path``.

    Codecs are tried in order (H.264 ``avc1``, then MPEG-4 ``mp4v``). The
    first writer that reports ``isOpened()`` is returned; failed attempts
    are released before moving on.

    Raises:
        RuntimeError: If no codec could be opened.
    """
    codecs = (
        ("avc1", "H.264/avc1"),
        ("mp4v", "MPEG-4/mp4v"),
    )
    target = str(video_path)
    frame_size = (width, height)
    for code, _description in codecs:
        candidate = cv2.VideoWriter(
            target, cv2.VideoWriter_fourcc(*code), float(fps), frame_size
        )
        if candidate.isOpened():
            return candidate
        candidate.release()

    tried = ", ".join(description for _code, description in codecs)
    raise RuntimeError(
        f"Could not open MP4 writer at {video_path}. Tried "
        + tried
        + ". Install an OpenCV build with MP4 codec support or try another machine."
    )