Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions api/experimentation/stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""Bayesian statistics for experiment results.

Compares a treatment variant against control and reports, in plain terms:
how much better/worse the treatment did (``lift``), how sure we are
(``chance_to_win`` and the credible interval), and whether traffic was split
fairly between variants (``srm_p_value``). All inputs are summary numbers; no
raw events reach this module.
"""

import math
from collections.abc import Sequence
from dataclasses import dataclass
from statistics import NormalDist

# z-score that bounds the central 95% of a normal curve: the interval runs
# from the mean minus to the mean plus this many standard deviations.
_Z_95 = 1.959963984540054
# Standard normal distribution (NormalDist defaults), built once at import
# and shared by every inference.
_STANDARD_NORMAL = NormalDist()


@dataclass(frozen=True)
class VariantStats:
    """Sufficient statistics describing one experiment variant.

    Three running totals stand in for the per-identity values: how many
    identities there are, what their values add up to, and what the squares
    of their values add up to. For a conversion metric every identity
    contributes 0 or 1, so ``sum`` is the number of conversions; for a value
    metric (e.g. revenue) it is the overall total. The average and the spread
    are both derived from these totals, so per-identity rows are not needed.
    """

    n: int  # number of identities in the variant
    sum: float  # sum of the per-identity values
    sum_squares: float  # sum of the squared per-identity values

    @property
    def mean(self) -> float:
        """Average per-identity value.

        Divides by ``n``, so an empty variant raises ``ZeroDivisionError``.
        """
        total, count = self.sum, self.n
        return total / count

    @property
    def variance(self) -> float:
        """Sample variance of the per-identity values (``n - 1`` denominator).

        Divides by ``n`` and by ``n - 1``, so ``n`` of 0 or 1 raises
        ``ZeroDivisionError``.
        """
        squared_deviations = self.sum_squares - self.sum**2 / self.n
        estimate = squared_deviations / (self.n - 1)
        # Rounding can leave the estimate a hair below zero when every value
        # is equal; a variance is never negative, so floor it at zero.
        return estimate if estimate > 0.0 else 0.0


@dataclass(frozen=True)
class Inference:
    """Outcome of comparing one treatment variant against control."""

    # Relative change versus control, e.g. 0.12 means +12%.
    lift: float
    # Lower and upper ends of the 95% credible interval for the true lift.
    ci_low: float
    ci_high: float
    # Probability, between 0 and 1, that the treatment really beats control.
    chance_to_win: float


def compare_to_control(
    control: VariantStats,
    treatment: VariantStats,
) -> Inference | None:
    """Estimate how the treatment performed relative to control.

    Returns the relative lift, a 95% interval around it and the probability
    that the treatment is ahead, or ``None`` when no inference can be made:
    either arm has fewer than two observations, or the control mean is zero
    or negative.
    """
    # With fewer than two observations an arm has no spread to measure. This
    # check comes first so the mean of an empty arm is never evaluated.
    if control.n < 2 or treatment.n < 2:
        return None
    baseline = control.mean
    # A relative lift against a non-positive baseline is meaningless, and a
    # baseline of exactly zero would divide by zero below.
    if baseline <= 0:
        return None

    treated = treatment.mean
    lift = (treated - baseline) / baseline

    # Delta-method variance of the ratio of the two means. Each arm adds its
    # own noise; the arms are independent, so there is no covariance term.
    treatment_term = treatment.variance / (treatment.n * baseline**2)
    control_term = treated**2 * control.variance / (control.n * baseline**4)
    standard_error = math.sqrt(treatment_term + control_term)

    if standard_error == 0:
        # No uncertainty at all: the interval collapses onto the lift and the
        # verdict is certain (or an even split when there is no difference).
        if lift == 0:
            certainty = 0.5
        else:
            certainty = float(lift > 0)
        return Inference(lift=lift, ci_low=lift, ci_high=lift, chance_to_win=certainty)

    # Model the true lift as a normal curve centred on `lift` with spread
    # `standard_error`: the interval is its central 95% and chance_to_win is
    # the share of the curve lying above zero.
    margin = _Z_95 * standard_error
    return Inference(
        lift=lift,
        ci_low=lift - margin,
        ci_high=lift + margin,
        chance_to_win=_STANDARD_NORMAL.cdf(lift / standard_error),
    )


def srm_p_value(
    observed: Sequence[int],
    expected_shares: Sequence[float],
) -> float | None:
    """Sample ratio mismatch check: was traffic split as configured?

    Returns the probability that random assignment alone would drift from the
    configured split at least as much as we observed. A tiny value (< 0.001 by
    convention) means the split is broken and the results can't be trusted.

    Args:
        observed: Number of identities seen in each variant.
        expected_shares: Configured share of traffic for each variant, in the
            same order as ``observed``. Only the proportions matter: the
            shares are rescaled to sum to 1, so fractions, percentages and
            raw weights all describe the same split.

    Returns:
        The p-value, or ``None`` when the question is meaningless (no
        traffic, a single variant, or a share that is zero or negative).

    Raises:
        ValueError: If ``observed`` and ``expected_shares`` differ in length.
    """
    total = sum(observed)
    if len(observed) < 2 or total == 0 or any(s <= 0 for s in expected_shares):
        return None

    # The chi-squared test is only valid when the expected counts add up to
    # the observed total, so rescale the shares by their sum. When they
    # already sum to 1 this divides by 1.0 and changes nothing.
    share_total = math.fsum(expected_shares)
    expected_counts = [total * share / share_total for share in expected_shares]

    # Chi-squared statistic: total squared gap between observed and expected
    # counts, scaled by what's expected. Bigger gap == bigger number.
    statistic = sum(
        (count - expected) ** 2 / expected
        for count, expected in zip(observed, expected_counts, strict=True)
    )
    return _chi_squared_survival(statistic, degrees_of_freedom=len(observed) - 1)


def _chi_squared_survival(statistic: float, degrees_of_freedom: int) -> float:
    """Return the chi-squared survival function, i.e. the p-value.

    This is the probability that a chi-squared variable with
    ``degrees_of_freedom`` degrees of freedom is at least ``statistic``. A
    statistic of zero or below means no gap at all, which is certain (1.0).
    """
    if statistic <= 0:
        return 1.0

    # The standard library has no chi-squared distribution. For integer
    # degrees of freedom the survival function Q(k/2, y), with y = statistic/2,
    # is built up exactly from a base case using the recurrence
    #     Q(a+1, y) = Q(a, y) + y^a * e^-y / Gamma(a+1).
    half_statistic = statistic / 2.0
    target_shape = degrees_of_freedom / 2.0

    if degrees_of_freedom % 2:
        # Odd degrees of freedom start from Q(1/2, y) = erfc(sqrt(y)).
        shape = 0.5
        tail = math.erfc(math.sqrt(half_statistic))
    else:
        # Even degrees of freedom start from Q(1, y) = e^-y.
        shape = 1.0
        tail = math.exp(-half_statistic)

    while shape + 1.0 <= target_shape:
        # The added term is evaluated in log space before exponentiating.
        log_term = (
            shape * math.log(half_statistic)
            - half_statistic
            - math.lgamma(shape + 1.0)
        )
        tail += math.exp(log_term)
        shape += 1.0
    return tail
183 changes: 183 additions & 0 deletions api/tests/unit/experimentation/test_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
import pytest

from experimentation.stats import (
Inference,
VariantStats,
compare_to_control,
srm_p_value,
)


def test_variant_stats__sufficient_statistics__derive_mean_and_variance() -> None:
    # Given a 0/1 metric where 100 of 1000 identities converted, so the sum
    # and the sum of squares are the same number
    conversions = VariantStats(n=1000, sum=100.0, sum_squares=100.0)
    expected_variance = 90.0 / 999.0

    # When
    mean, variance = conversions.mean, conversions.variance

    # Then
    assert mean == 0.1
    assert variance == pytest.approx(expected_variance)


def test_variant_stats__float_noise__variance_clamped_to_zero() -> None:
    # Given totals whose rounding leaves the raw variance a hair below zero
    noisy = VariantStats(n=2, sum=2.0, sum_squares=1.9999999999999996)

    # When
    variance = noisy.variance

    # Then the negative value is floored at zero
    assert variance == 0.0


def test_compare_to_control__more_conversions__positive_lift_inference() -> None:
    # Given 1000 identities per arm, converting at 10% (control) and 12%
    # (treatment)
    baseline = VariantStats(n=1000, sum=100.0, sum_squares=100.0)
    variant = VariantStats(n=1000, sum=120.0, sum_squares=120.0)
    tolerance = 1e-4

    # When
    result = compare_to_control(baseline, variant)

    # Then
    assert result is not None
    assert result.lift == pytest.approx(0.2)
    assert result.ci_low == pytest.approx(-0.10074, abs=tolerance)
    assert result.ci_high == pytest.approx(0.50074, abs=tolerance)
    assert result.chance_to_win == pytest.approx(0.90379, abs=tolerance)


def test_compare_to_control__identical_arms__chance_is_even() -> None:
    # Given the very same statistics used for both control and treatment
    shared = VariantStats(n=1000, sum=100.0, sum_squares=100.0)

    # When
    result = compare_to_control(shared, shared)

    # Then there is no lift, the odds are even and the interval is symmetric
    assert result is not None
    assert result.lift == 0.0
    assert result.chance_to_win == 0.5
    assert result.ci_low == pytest.approx(-result.ci_high)


@pytest.mark.parametrize(
    "treatment, expected",
    [
        pytest.param(
            VariantStats(n=10, sum=20.0, sum_squares=40.0),
            Inference(lift=1.0, ci_low=1.0, ci_high=1.0, chance_to_win=1.0),
            id="better",
        ),
        pytest.param(
            VariantStats(n=10, sum=5.0, sum_squares=2.5),
            Inference(lift=-0.5, ci_low=-0.5, ci_high=-0.5, chance_to_win=0.0),
            id="worse",
        ),
        pytest.param(
            VariantStats(n=10, sum=10.0, sum_squares=10.0),
            Inference(lift=0.0, ci_low=0.0, ci_high=0.0, chance_to_win=0.5),
            id="equal",
        ),
    ],
)
def test_compare_to_control__zero_variance_arms__degenerate_certainty(
    treatment: VariantStats,
    expected: Inference,
) -> None:
    # Given a control whose values are all identical (zero variance), paired
    # with an equally constant treatment
    control = VariantStats(n=10, sum=10.0, sum_squares=10.0)

    # When
    result = compare_to_control(control, treatment)

    # Then the interval collapses onto the lift and the verdict is certain
    assert result == expected


def test_compare_to_control__zero_control_mean__returns_none() -> None:
    # Given a control without a single conversion, so there is no baseline to
    # express a relative lift against
    empty_handed = VariantStats(n=1000, sum=0.0, sum_squares=0.0)
    converting = VariantStats(n=1000, sum=120.0, sum_squares=120.0)

    # When
    result = compare_to_control(empty_handed, converting)

    # Then
    assert result is None


def test_compare_to_control__negative_control_mean__returns_none() -> None:
    # Given a control that averages below zero (e.g. a revenue metric with
    # refunds), against which a relative lift has no meaning
    loss_making = VariantStats(n=1000, sum=-50.0, sum_squares=600.0)
    converting = VariantStats(n=1000, sum=120.0, sum_squares=120.0)

    # When
    result = compare_to_control(loss_making, converting)

    # Then
    assert result is None


@pytest.mark.parametrize(
    "control_n, treatment_n",
    [
        pytest.param(1, 1000, id="control_too_small"),
        pytest.param(1000, 1, id="treatment_too_small"),
        pytest.param(0, 1000, id="control_empty"),
    ],
)
def test_compare_to_control__insufficient_observations__returns_none(
    control_n: int,
    treatment_n: int,
) -> None:
    def arm_of_ones(size: int) -> VariantStats:
        # Every identity contributes 1, so both totals equal the size.
        return VariantStats(n=size, sum=float(size), sum_squares=float(size))

    # Given at least one arm with fewer than two observations, which leaves
    # its variance undefined
    control = arm_of_ones(control_n)
    treatment = arm_of_ones(treatment_n)

    # When
    result = compare_to_control(control, treatment)

    # Then
    assert result is None


def test_srm_p_value__balanced_split__no_mismatch() -> None:
    # Given counts that land exactly on the configured 50/50 split
    counts = [5000, 5000]
    split = [0.5, 0.5]

    # When
    result = srm_p_value(counts, split)

    # Then there is no gap to explain
    assert result == pytest.approx(1.0)


@pytest.mark.parametrize(
    "observed, shares, expected_p",
    [
        # chi-squared = 4.0, 1 dof
        pytest.param([5100, 4900], [0.5, 0.5], 0.04550, id="one_dof"),
        # chi-squared = 2.0, 2 dof: survival is exp(-1)
        pytest.param(
            [3400, 3300, 3300],
            [1 / 3, 1 / 3, 1 / 3],
            0.36788,
            id="two_dof",
        ),
        # chi-squared = 8.0, 3 dof
        pytest.param(
            [2600, 2500, 2500, 2400],
            [0.25, 0.25, 0.25, 0.25],
            0.04601,
            id="three_dof",
        ),
    ],
)
def test_srm_p_value__known_chi_squared__matches_reference(
    observed: list[int],
    shares: list[float],
    expected_p: float,
) -> None:
    # Given counts whose chi-squared statistic was worked out by hand
    # When
    result = srm_p_value(observed, shares)

    # Then the p-value agrees with the reference to four decimal places
    assert result == pytest.approx(expected_p, abs=1e-4)


def test_srm_p_value__heavy_imbalance__fails_threshold() -> None:
    # Given traffic that arrived 60/40 although 50/50 was configured
    counts = [6000, 4000]
    split = [0.5, 0.5]
    threshold = 0.001

    # When
    result = srm_p_value(counts, split)

    # Then the p-value falls below the conventional mismatch threshold
    assert result is not None
    assert result < threshold


@pytest.mark.parametrize(
    "observed, shares",
    [
        pytest.param([0, 0], [0.5, 0.5], id="no_observations"),
        pytest.param([5000, 5000], [1.0, 0.0], id="zero_share"),
        pytest.param([10000], [1.0], id="single_variant"),
    ],
)
def test_srm_p_value__not_computable__returns_none(
    observed: list[int],
    shares: list[float],
) -> None:
    # Given inputs for which the chi-squared test has no meaning
    # When
    result = srm_p_value(observed, shares)

    # Then
    assert result is None
Loading
Loading