Instructions to use kernels-community/megablocks with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Kernels
How to use kernels-community/megablocks with Kernels:
```python
# !pip install kernels
from kernels import get_kernel

kernel = get_kernel("kernels-community/megablocks")
```
- Notebooks
- Google Colab
- Kaggle
| # Copyright 2024 Databricks | |
| # SPDX-License-Identifier: Apache-2.0 | |
| import numpy as np | |
| import pytest | |
| import torch | |
| try: | |
| from megablocks._ops import ops as backend # type: ignore | |
| except ModuleNotFoundError as e: | |
| raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e | |
| from megablocks import ops | |
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
    """Return a one-element 1-D view of ``x`` when ``x`` is zero-dimensional.

    Tensors that already have one or more dimensions are returned unchanged
    (the very same object, not a copy).
    """
    if x.dim() == 0:
        return x.view(1)
    return x
# Test-case table. Each tuple appears to be (tokens, num_centers, top_k),
# matching the parameter order of the test functions in this file.
# The list is built from two grids plus one degenerate single-center case;
# within each grid the last element varies slowest and the middle element
# varies fastest.
REPLICATE_TESTS = [
    (8, 1, 1),
    *(
        (8, centers, k)
        for k in (1, 2, 4, 8)
        for centers in (2, 4, 8)
    ),
    *(
        (16384, centers, k)
        for k in (1, 2, 4, 8)
        for centers in (2, 4, 8, 16, 32, 64, 128)
    ),
]
@pytest.mark.parametrize(('tokens', 'num_centers', 'top_k'), REPLICATE_TESTS)
def test_replicate(tokens: int, num_centers: int, top_k: int):
    """Check ``ops.replicate`` against a straightforward NumPy reference.

    Random token-to-center assignments are turned into per-center counts with
    ``ops.histogram`` and then into bin boundaries with
    ``ops.inclusive_cumsum``. Random half-precision weights of shape
    ``(top_k, num_centers)`` are replicated by the op and by the reference
    implementation below, and the two results must match exactly.

    Args:
        tokens: Number of tokens, i.e. the number of output columns.
        num_centers: Number of centers (bins).
        top_k: Number of rows in the weight matrix.
    """
    tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
    bins = ops.inclusive_cumsum(tokens_per_center, 0)
    # With a single center the cumsum result may be 0-d; make it indexable.
    bins = promote_scalar(bins)
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    def replicate(x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
        """Reference: fill columns ``[bins[i-1], bins[i])`` of each row with ``x[row, i]``."""
        x = x.cpu().numpy()
        bins = bins.cpu().numpy()
        out = np.zeros((x.shape[0], num_outputs))
        for batch_idx in range(x.shape[0]):
            # `start` carries over between bins, so each bin begins where the
            # previous one ended.
            start = 0
            for i, end in enumerate(bins):
                value = x[batch_idx, i]
                while start < end:
                    out[batch_idx, start] = value
                    start += 1
        return torch.from_numpy(out).cuda().half()

    out = ops.replicate(center_weights, bins, tokens)
    expected_out = replicate(center_weights, bins, tokens)
    # Replication only copies values, so exact equality is expected.
    assert torch.all(torch.eq(out, expected_out))
@pytest.mark.parametrize(('tokens', 'num_centers', 'top_k'), REPLICATE_TESTS)
def test_replicate_backward(tokens: int, num_centers: int, top_k: int):
    """Check ``backend.replicate_backward`` against a closed-form expectation.

    The forward output of ``ops.replicate`` is fed back in as the incoming
    gradient. The result written into ``out`` is compared with
    ``center_weights * tokens_per_center`` — presumably because the backward
    pass sums each bin of the gradient, and every bin holds ``count`` copies
    of the same weight.

    Args:
        tokens: Number of tokens, i.e. the number of replicated columns.
        num_centers: Number of centers (bins).
        top_k: Number of rows in the weight matrix.
    """
    tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
    bins = ops.inclusive_cumsum(tokens_per_center, 0)
    # With a single center the cumsum result may be 0-d; make it indexable.
    bins = promote_scalar(bins)
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    # Use the replicated weights themselves as the upstream gradient.
    grad = ops.replicate(center_weights, bins, tokens)

    # The backend op takes a preallocated output tensor rather than returning one.
    out = torch.empty_like(center_weights)
    backend.replicate_backward(grad, bins, out)
    expected_out = center_weights * tokens_per_center.view([1, num_centers])

    # NOTE: This floating-point reduction could be a problem for training stability and accuracy.
    assert torch.allclose(out, expected_out, rtol=1e-2)