Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions experimental/air/cmd/compute.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package aircmd

import (
"errors"
"fmt"
"strings"
)

// gpuType names an accelerator type exactly as it is submitted to the training
// service. The digit before the "x" in the name is the partition count, so
// GPU_8xH100 means 8 GPUs.
type gpuType string

// Accelerator types accepted by parseGPUType.
const (
	gpuType1xA10  gpuType = "GPU_1xA10"
	gpuType1xH100 gpuType = "GPU_1xH100"
	gpuType8xH100 gpuType = "GPU_8xH100"
)

// gpuTypes enumerates every valid accelerator type. Its order is the order in
// which the types are listed in validation error messages.
var gpuTypes = []gpuType{
	gpuType1xA10,
	gpuType1xH100,
	gpuType8xH100,
}

// validGPUTypesHint renders the entries of gpuTypes as a comma-separated hint
// ("valid types are: A, B, C") for use in validation error messages.
func validGPUTypesHint() string {
	var hint strings.Builder
	hint.WriteString("valid types are: ")
	for idx, kind := range gpuTypes {
		if idx > 0 {
			hint.WriteString(", ")
		}
		hint.WriteString(string(kind))
	}
	return hint.String()
}

// parseGPUType converts the accelerator_type string from the run YAML into a
// gpuType. Only exact matches are accepted because the server's lookup is
// case-sensitive; any other value yields an error that lists the valid types.
func parseGPUType(value string) (gpuType, error) {
	candidate := gpuType(value)
	known := candidate == gpuType1xA10 ||
		candidate == gpuType8xH100 ||
		candidate == gpuType1xH100
	if !known {
		return "", fmt.Errorf("invalid GPU type %q: %s", value, validGPUTypesHint())
	}
	return candidate, nil
}

// gpusPerNode reports how many GPUs one node of type g provides. This is the
// partition count encoded in the type name (GPU_1xH100 -> 1, GPU_8xH100 -> 8).
// Accelerators are allocated in whole nodes, so num_accelerators must be a
// round multiple of the returned value. A type that is not one of the known
// constants yields 0 and an error.
func gpusPerNode(g gpuType) (int, error) {
	var count int
	switch g {
	case gpuType8xH100:
		count = 8
	case gpuType1xH100, gpuType1xA10:
		count = 1
	default:
		return 0, fmt.Errorf("invalid GPU type %q", string(g))
	}
	return count, nil
}

// computeConfig is the `compute` block of the run YAML: which accelerators to
// use and how many. The constraints noted on each field are enforced by
// validate.
type computeConfig struct {
// NumAccelerators is the requested accelerator count. validate requires it
// to be positive and a multiple of the per-node GPU count of AcceleratorType.
NumAccelerators int `yaml:"num_accelerators"`
// AcceleratorType is the accelerator type name. validate requires it to be
// accepted by parseGPUType.
AcceleratorType string `yaml:"accelerator_type"`
// NodePoolID may be left empty. validate rejects a config that sets both
// NodePoolID and PoolName.
NodePoolID string `yaml:"node_pool_id"`
// PoolName may be left empty. validate rejects a config that sets both
// NodePoolID and PoolName.
PoolName string `yaml:"pool_name"`
}

// validate checks the compute block against the backend's constraints and
// returns the first violation found, or nil when the block is valid.
//
// Checks run in this order:
//  1. accelerator_type must be accepted by parseGPUType;
//  2. num_accelerators must be positive;
//  3. num_accelerators must be a multiple of the per-node GPU count;
//  4. node_pool_id and pool_name must not both be set.
//
// Every error names the offending field with a "compute" prefix.
func (c computeConfig) validate() error {
	g, err := parseGPUType(c.AcceleratorType)
	if err != nil {
		return fmt.Errorf("compute.accelerator_type: %w", err)
	}

	// Checked before the modulo test: a negative multiple (e.g. -8) would
	// otherwise satisfy it.
	if c.NumAccelerators <= 0 {
		return fmt.Errorf("compute.num_accelerators must be positive, got %d", c.NumAccelerators)
	}

	perNode, err := gpusPerNode(g)
	if err != nil {
		// Only reachable if parseGPUType and gpusPerNode disagree about the
		// set of valid types. Attribute the failure to the field like every
		// other error returned here.
		return fmt.Errorf("compute.accelerator_type: %w", err)
	}
	if c.NumAccelerators%perNode != 0 {
		return fmt.Errorf("compute.num_accelerators for %s must be a multiple of %d, got %d", c.AcceleratorType, perNode, c.NumAccelerators)
	}

	if c.NodePoolID != "" && c.PoolName != "" {
		return errors.New("compute: cannot specify both node_pool_id and pool_name")
	}

	return nil
}
89 changes: 89 additions & 0 deletions experimental/air/cmd/compute_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package aircmd

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestParseGPUType checks that each supported accelerator_type string parses
// without error to its matching gpuType constant.
func TestParseGPUType(t *testing.T) {
	cases := []struct {
		input    string
		expected gpuType
	}{
		{input: "GPU_1xA10", expected: gpuType1xA10},
		{input: "GPU_8xH100", expected: gpuType8xH100},
		{input: "GPU_1xH100", expected: gpuType1xH100},
	}
	for _, tc := range cases {
		t.Run(tc.input, func(t *testing.T) {
			parsed, err := parseGPUType(tc.input)
			require.NoError(t, err)
			assert.Equal(t, tc.expected, parsed)
		})
	}
}

// TestParseGPUTypeInvalid checks that parseGPUType rejects every input that is
// not an exact match, and that the error carries the list of valid types.
func TestParseGPUTypeInvalid(t *testing.T) {
	rejected := []string{
		// Wrong casing is rejected rather than fixed up.
		"gpu_1xa10",
		"GPU_1XA10",
		// Well-formed name that is not a supported type.
		"GPU_2xH100",
		// Legacy types can no longer be submitted.
		"h100_80gb",
		"a10",
		// Unknown and empty values.
		"b200",
		"",
	}
	for _, input := range rejected {
		t.Run(input, func(t *testing.T) {
			_, parseErr := parseGPUType(input)
			require.Error(t, parseErr)
			assert.Contains(t, parseErr.Error(), "valid types are")
		})
	}
}

// TestGPUsPerNode checks the per-node GPU count reported for every known type
// and that an unknown type produces an error.
func TestGPUsPerNode(t *testing.T) {
	t.Helper()

	cases := []struct {
		kind     gpuType
		expected int
	}{
		{kind: gpuType1xA10, expected: 1},
		{kind: gpuType1xH100, expected: 1},
		{kind: gpuType8xH100, expected: 8},
	}
	for _, tc := range cases {
		name := string(tc.kind)
		t.Run(name, func(t *testing.T) {
			count, err := gpusPerNode(tc.kind)
			require.NoError(t, err)
			assert.Equal(t, tc.expected, count)
		})
	}

	// A value outside the known constants must be rejected.
	unknown := gpuType("nonsense")
	_, err := gpusPerNode(unknown)
	require.Error(t, err)
}

// TestComputeConfigValidate runs computeConfig.validate over valid and invalid
// compute blocks. For invalid blocks it checks that the error mentions the
// expected substring.
func TestComputeConfigValidate(t *testing.T) {
	tests := []struct {
		name    string
		cfg     computeConfig
		wantErr string // substring; empty means the config is valid
	}{
		{"single node", computeConfig{NumAccelerators: 8, AcceleratorType: "GPU_8xH100"}, ""},
		{"multiple nodes", computeConfig{NumAccelerators: 16, AcceleratorType: "GPU_8xH100"}, ""},
		{"single-gpu partitions", computeConfig{NumAccelerators: 3, AcceleratorType: "GPU_1xH100"}, ""},
		{"with node pool", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", NodePoolID: "pool-123"}, ""},
		{"with pool name", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", PoolName: "my-pool"}, ""},
		{"unknown type", computeConfig{NumAccelerators: 8, AcceleratorType: "b200"}, "accelerator_type"},
		{"legacy type rejected", computeConfig{NumAccelerators: 8, AcceleratorType: "h100_80gb"}, "accelerator_type"},
		{"non-positive count", computeConfig{NumAccelerators: 0, AcceleratorType: "GPU_1xH100"}, "must be positive"},
		// -8 is an exact multiple of 8, so only the positivity check can
		// reject it; this pins that the check runs before the modulo test.
		{"negative multiple of node size", computeConfig{NumAccelerators: -8, AcceleratorType: "GPU_8xH100"}, "must be positive"},
		{"count not a multiple", computeConfig{NumAccelerators: 4, AcceleratorType: "GPU_8xH100"}, "multiple of 8"},
		{"both pool fields", computeConfig{NumAccelerators: 1, AcceleratorType: "GPU_1xA10", NodePoolID: "p", PoolName: "n"}, "both"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := tt.cfg.validate()
			if tt.wantErr == "" {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
			assert.Contains(t, err.Error(), tt.wantErr)
		})
	}
}
Loading