diff --git a/examples/ml_dtypes/README.md b/examples/ml_dtypes/README.md new file mode 100644 index 0000000..fcb7370 --- /dev/null +++ b/examples/ml_dtypes/README.md @@ -0,0 +1,18 @@ +# ml_dtypes Examples + +Each sub-directory contains a self-contained example. The order in +which the examples are to appear is specified in `order.json` (an +array of directory names in the expected order). + +In each example directory you'll find: + +* `config.toml` - must conform to the specification outlined here: + https://docs.pyscript.net/latest/user-guide/configuration/ This is + parsed and ultimately turned into a JSON representation as part of + the package's API object. +* `setup.py` - Python code for contextual and environmental setup, + NOT SEEN BY THE END USER, but is run before the `code.py` code is + evaluated. Allows us to create useful (IPython) shims, avoid + repeating boilerplate and whatnot. +* `code.py` - the actual code added to the editor which forms the + practical example of using the package. diff --git a/examples/ml_dtypes/bfloat16_basics/code.py b/examples/ml_dtypes/bfloat16_basics/code.py new file mode 100644 index 0000000..3b6ded7 --- /dev/null +++ b/examples/ml_dtypes/bfloat16_basics/code.py @@ -0,0 +1,53 @@ +""" +A first look at ml_dtypes: the bfloat16 dtype. + +ml_dtypes provides NumPy dtype extensions used widely in machine +learning (bfloat16, several float8 variants, sub-byte ints, etc.). +Importing the package also registers the dtypes with NumPy, so you +can use them anywhere a regular dtype is accepted. + +Docs: https://github.com/jax-ml/ml_dtypes +""" +import numpy as np +from ml_dtypes import bfloat16 +from IPython.core.display import display, HTML + +heading("1. Creating a bfloat16 array") +note( + "bfloat16 is a 16-bit float with the same exponent range as " + "float32 (8 bits) but only 7 mantissa bits. It trades precision " + "for range, which suits ML workloads." +) + +# Construct a bfloat16 array. 
Once ml_dtypes is imported, NumPy +# also recognizes the dtype by string name. +weights_f32 = rng.normal(size=8).astype(np.float32) +weights_bf16 = weights_f32.astype(bfloat16) + +note("The same eight values, side by side at different precisions:") +comparison = np.empty( + (3, 8), + dtype=object, +) +comparison[0] = [f"{x:.6f}" for x in weights_f32] +comparison[1] = [f"{x:.6f}" for x in weights_bf16.astype(np.float32)] +comparison[2] = [f"{float(a) - float(b):+.2e}" + for a, b in zip(weights_f32, weights_bf16.astype(np.float32))] +rows = "".join( + "
{np.dtype('bfloat16')}")
+note(f"itemsize: {np.dtype('bfloat16').itemsize} bytes")
diff --git a/examples/ml_dtypes/bfloat16_basics/config.toml b/examples/ml_dtypes/bfloat16_basics/config.toml
new file mode 100644
index 0000000..226c052
--- /dev/null
+++ b/examples/ml_dtypes/bfloat16_basics/config.toml
@@ -0,0 +1 @@
+packages = ["ml_dtypes", "numpy"]
diff --git a/examples/ml_dtypes/bfloat16_basics/setup.py b/examples/ml_dtypes/bfloat16_basics/setup.py
new file mode 100644
index 0000000..c3dcde2
--- /dev/null
+++ b/examples/ml_dtypes/bfloat16_basics/setup.py
@@ -0,0 +1,48 @@
+"""
+Shim IPython's display API onto PyScript so example code written in a
+Jupyter/IPython idiom runs unmodified in the browser.
+"""
+
+import sys
+import types
+import js
+from pyscript import window, HTML, display as _display
+
+js.alert = window.alert  # presumably so alert() works off the main thread; confirm
+
+
+def display(*args, **kwargs):
+    """Send output to the example target unless the caller names a target."""
+    # setdefault: an explicit target= is honored, not a duplicate-kwarg error.
+    kwargs.setdefault("target", __pyscript_display_target__)
+    return _display(*args, **kwargs)
+
+
+# Build a stand-in IPython package, innermost module first, exposing display,
+# HTML and a get_ipython stub, then register it so `import IPython...` works.
+core_display = types.ModuleType("IPython.core.display")
+core_display.display, core_display.HTML = display, HTML
+core = types.ModuleType("IPython.core")
+core.display = core_display
+ipython = types.ModuleType("IPython")
+ipython.core, ipython.display = core, core_display
+ipython.get_ipython = lambda: None  # stub: always returns None
+sys.modules.update({
+    "IPython": ipython, "IPython.core": core,
+    "IPython.core.display": core_display, "IPython.display": core_display,
+})
+
+
+def heading(text, level=2):
+ display(HTML(f"{text}
"), append=True) + + +# Package imports for this example. +import numpy as np +from ml_dtypes import bfloat16 + +rng = np.random.default_rng(0) diff --git a/examples/ml_dtypes/float8_landscape/code.py b/examples/ml_dtypes/float8_landscape/code.py new file mode 100644 index 0000000..b47d100 --- /dev/null +++ b/examples/ml_dtypes/float8_landscape/code.py @@ -0,0 +1,80 @@ +# --------------------------------------------------------------------- +# A tour of the smaller dtypes: float8 variants and int4. +# --------------------------------------------------------------------- + +heading("How many distinct values does each format represent?") +note( + "float8_e4m3fn favors precision (4 exponent, 3 mantissa). " + "float8_e5m2 favors range (5 exponent, 2 mantissa). " + "Both fit in a single byte, but they cover the number line " + "very differently." +) + +# Enumerate every byte pattern for each 8-bit float type and decode +# it to float32. This gives us the complete value set of each format. +all_bytes = np.arange(256, dtype=np.uint8) + + +def decode_all(dtype): + """Return the float32 values of all 256 byte patterns for a dtype.""" + return all_bytes.view(dtype).astype(np.float32) + + +e4m3_values = decode_all(float8_e4m3fn) +e5m2_values = decode_all(float8_e5m2) + +# Drop NaNs for plotting. +e4m3_finite = e4m3_values[np.isfinite(e4m3_values)] +e5m2_finite = e5m2_values[np.isfinite(e5m2_values)] + +note( + f"float8_e4m3fn finite values: {len(e4m3_finite)}, " + f"max = {e4m3_finite.max():g}{text}
"), append=True) + + +import numpy as np +import matplotlib.pyplot as plt +from ml_dtypes import bfloat16, float8_e4m3fn, float8_e5m2, int4 + +rng = np.random.default_rng(0) diff --git a/examples/ml_dtypes/low_precision_summation/code.py b/examples/ml_dtypes/low_precision_summation/code.py new file mode 100644 index 0000000..df4c575 --- /dev/null +++ b/examples/ml_dtypes/low_precision_summation/code.py @@ -0,0 +1,57 @@ +# --------------------------------------------------------------------- +# Quirks of low-precision arithmetic: a cautionary tale. +# --------------------------------------------------------------------- + +heading("Summing 10,000 small numbers in bfloat16") +note( + "We draw 10,000 values uniformly from [0, 1). The true sum " + "is around 5,000. Watch what happens when bfloat16 accumulates." +) + +samples = rng.uniform(size=10_000).astype(bfloat16) + +# Naive: accumulate in bfloat16. Once the running total reaches 256, +# adding a value below 1 has no effect (the next representable +# bfloat16 above 256 is 258). +naive_sum = samples.sum() + +# Better: accumulate in float32 and cast back at the end. +careful_sum = samples.sum(dtype=np.float32).astype(bfloat16) + +# Reference: full float32 sum. +true_sum = samples.astype(np.float32).sum() + +note( + f"Naive bfloat16 sum: {float(naive_sum):.1f}{text}
"), append=True) + + +import numpy as np +import matplotlib.pyplot as plt +from ml_dtypes import bfloat16 + +rng = np.random.default_rng(0) diff --git a/examples/ml_dtypes/order.json b/examples/ml_dtypes/order.json new file mode 100644 index 0000000..3a0096f --- /dev/null +++ b/examples/ml_dtypes/order.json @@ -0,0 +1,5 @@ +[ + "bfloat16_basics", + "low_precision_summation", + "float8_landscape" +]