Documentation ¶
Overview ¶
Package stablehlo helps build a StableHLO program (text format) to then be JIT-compiled and executed by PJRT (github.com/gomlx/gopjrt/pjrt).
Among its features:
- Translates an API to rendered (human-readable) StableHLO text.
- Shape inference: it calculates the output shapes for operations.
- Written purely in Go, no C/C++ external dependencies.
It was written as a replacement for `gopjrt/xlabuilder` and attempts to keep a similar or identical interface.
See the StableHLO documentation and specification at https://openxla.org/stablehlo/spec
Index ¶
- Constants
- func BatchNormGradient(operand, scale, mean, variance, gradOutput *Value, epsilon float32, ...) (gradOperand *Value, gradScale *Value, gradOffset *Value, err error)
- func BatchNormTraining(operand, scale, offset *Value, epsilon float32, featureAxis int) (normalized *Value, batchMean *Value, batchVariance *Value, err error)
- func ConvertToValidName(name string) string
- func NormalizeIdentifier(name string) string
- type Builder
- func (b *Builder) Build() ([]byte, error)
- func (b *Builder) Main(inputs ...*Value) *Function
- func (b *Builder) Meshes() []*shardy.DeviceMesh
- func (b *Builder) NewFunction(name string, inputs ...*Value) *Function
- func (b *Builder) NewShardingSpec() *shardy.ShardingSpec
- func (b *Builder) NewShardingSpecByMeshIx(meshIdx int) *shardy.ShardingSpec
- func (b *Builder) WithNumPartitions(n int) *Builder
- func (b *Builder) WithNumReplicas(n int) *Builder
- func (b *Builder) WithShardy(meshes ...*shardy.DeviceMesh) *Builder
- func (b *Builder) Write(writer io.Writer) error
- type Computation
- type DotGeneralBuilder
- func (b *DotGeneralBuilder) Algorithm(algorithm *types.DotGeneralAlgorithm) *DotGeneralBuilder
- func (b *DotGeneralBuilder) Done() (*Value, error)
- func (b *DotGeneralBuilder) OutputDType(dtype dtypes.DType) *DotGeneralBuilder
- func (b *DotGeneralBuilder) Precision(lhsPrecision, rhsPrecision types.DotGeneralPrecisionType) *DotGeneralBuilder
- type Function
- func (fn *Function) Closure() *Function
- func (fn *Function) ConstantFromFlatAndDimensions(flat any, dimensions ...int) (*Value, error)
- func (fn *Function) ConstantFromScalar(value any) (*Value, error)
- func (fn *Function) Input(shape shapes.Shape) (*Value, error)
- func (fn *Function) InputWithAttributes(shape shapes.Shape, attributes map[string]any) (*Value, error)
- func (fn *Function) InputWithSharding(shape shapes.Shape, shardingSpec *shardy.ShardingSpec) (*Value, error)
- func (fn *Function) InputWithShardingAndAttributes(shape shapes.Shape, shardingSpec *shardy.ShardingSpec, ...) (*Value, error)
- func (fn *Function) Iota(shape shapes.Shape, axis int) (*Value, error)
- func (fn *Function) NamedInput(name string, shape shapes.Shape) (*Value, error)
- func (fn *Function) NamedInputWithAttributes(name string, shape shapes.Shape, attributes map[string]any) (*Value, error)
- func (fn *Function) NamedInputWithSharding(name string, shape shapes.Shape, shardingSpec *shardy.ShardingSpec) (*Value, error)
- func (fn *Function) NamedInputWithShardingAndAttributes(name string, shape shapes.Shape, shardingSpec *shardy.ShardingSpec, ...) (*Value, error)
- func (fn *Function) Return(values ...*Value) error
- func (fn *Function) ReturnWithAttributes(values []*Value, attributes []map[string]any) error
- func (fn *Function) ReturnWithShardingAndAttributes(values []*Value, shardingSpecs []*shardy.ShardingSpec, ...) error
- func (fn *Function) Write(writer io.Writer, indentation string) error
- type Statement
- type Value
- func Abs(operand *Value) (*Value, error)
- func Add(lhs, rhs *Value) (*Value, error)
- func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, ...) (*Value, error)
- func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, ...) ([]*Value, error)
- func AllToAll(operand *Value, replicaGroups [][]int, ...) (*Value, error)
- func And(lhs, rhs *Value) (*Value, error)
- func Atan2(lhs, rhs *Value) (*Value, error)
- func BatchNormInference(operand, scale, offset, mean, variance *Value, epsilon float32, ...) (*Value, error)
- func BitcastConvert(operand *Value, targetDtype dtypes.DType) (*Value, error)
- func BroadcastInDim(operand *Value, target shapes.Shape, axesMapping []int) (*Value, error)
- func Cbrt(operand *Value) (*Value, error)
- func Ceil(operand *Value) (*Value, error)
- func Clamp(min, x, max *Value) (*Value, error)
- func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types.CollectiveConfig) (*Value, error)
- func CollectivePermute(operand *Value, sourceTargetPairs [][2]int, config ...*types.CollectiveConfig) (*Value, error)
- func Compare(lhs, rhs *Value, direction types.ComparisonDirection, ...) (*Value, error)
- func Complex(real, imag *Value) (*Value, error)
- func Concatenate(axis int, operands ...*Value) (*Value, error)
- func Convert(x *Value, dtype dtypes.DType) (*Value, error)
- func Convolution(input, kernel *Value, strides []int, paddings [][2]int, ...) (*Value, error)
- func Cosine(operand *Value) (*Value, error)
- func CountLeadingZeros(operand *Value) (*Value, error)
- func Divide(lhs, rhs *Value) (*Value, error)
- func Dot(lhs, rhs *Value) (*Value, error)
- func DynamicSlice(operand *Value, startIndices []*Value, sliceSizes []int) (*Value, error)
- func DynamicUpdateSlice(operand, update *Value, startIndices []*Value) (*Value, error)
- func Erf(operand *Value) (*Value, error)
- func Exponential(operand *Value) (*Value, error)
- func ExponentialMinusOne(operand *Value) (*Value, error)
- func FFT(x *Value, fftType types.FFTType, fftLength ...int) (*Value, error)
- func Floor(operand *Value) (*Value, error)
- func Gather(operand, startIndices *Value, indexVectorAxis int, ...) (*Value, error)
- func Imag(complex *Value) (*Value, error)
- func IsFinite(x *Value) (*Value, error)
- func Log(operand *Value) (*Value, error)
- func LogPlusOne(operand *Value) (*Value, error)
- func Logistic(operand *Value) (*Value, error)
- func Maximum(lhs, rhs *Value) (*Value, error)
- func Minimum(lhs, rhs *Value) (*Value, error)
- func MultiReduce(inputs, initialValues []*Value, reductionFn *Function, axes ...int) ([]*Value, error)
- func MultiReduceWindow(inputs, initialValues []*Value, reductionFn *Function, ...) ([]*Value, error)
- func MultiScatter(inputs []*Value, scatterIndices *Value, updates []*Value, ...) ([]*Value, error)
- func Multiply(lhs, rhs *Value) (*Value, error)
- func Negate(operand *Value) (*Value, error)
- func Not(operand *Value) (*Value, error)
- func Or(lhs, rhs *Value) (*Value, error)
- func Pad(x, fill *Value, paddingStart, paddingEnd, paddingInterior []int) (*Value, error)
- func Popcnt(operand *Value) (*Value, error)
- func Power(lhs, rhs *Value) (*Value, error)
- func RNGBitGenerator(state *Value, shape shapes.Shape, algorithm types.RNGBitGeneratorAlgorithm) (newState, values *Value, err error)
- func Real(complex *Value) (*Value, error)
- func Reduce(x, initialValue *Value, reductionFn *Function, axes ...int) (*Value, error)
- func ReduceWindow(input, initialValue *Value, reductionFn *Function, ...) (*Value, error)
- func Remainder(lhs, rhs *Value) (*Value, error)
- func Reshape(operand *Value, shape shapes.Shape) (*Value, error)
- func Reverse(x *Value, axes ...int) (*Value, error)
- func RoundNearestAfz(operand *Value) (*Value, error)
- func RoundNearestEven(operand *Value) (*Value, error)
- func Rsqrt(operand *Value) (*Value, error)
- func Scatter(input, scatterIndices, updates *Value, ...) (*Value, error)
- func Select(pred, onTrue, onFalse *Value) (*Value, error)
- func SelectAndScatter(input, scatterSource, initialValue *Value, selectFn, scatterFn *Function, ...) (*Value, error)
- func ShiftLeft(lhs, rhs *Value) (*Value, error)
- func ShiftRightArithmetic(lhs, rhs *Value) (*Value, error)
- func ShiftRightLogical(lhs, rhs *Value) (*Value, error)
- func Sign(operand *Value) (*Value, error)
- func Sine(operand *Value) (*Value, error)
- func Slice(x *Value, starts, limits, strides []int) (*Value, error)
- func Sqrt(operand *Value) (*Value, error)
- func Subtract(lhs, rhs *Value) (*Value, error)
- func Tan(operand *Value) (*Value, error)
- func Tanh(operand *Value) (*Value, error)
- func Transpose(x *Value, permutation ...int) (*Value, error)
- func Xor(lhs, rhs *Value) (*Value, error)
Constants ¶
const IndentationStep = " "
const MainFunctionName = "main"
Variables ¶
This section is empty.
Functions ¶
func BatchNormGradient ¶
func BatchNormGradient(operand, scale, mean, variance, gradOutput *Value, epsilon float32, featureAxis int) (gradOperand *Value, gradScale *Value, gradOffset *Value, err error)
BatchNormGradient calculates the batch normalization gradients with respect to the input, scale, and offset. https://openxla.org/xla/operation_semantics#batchnormgrad
The gradOutput is the adjoint gradient (the "V" in "VJP"), that is, the gradient with respect to the output of the batch normalization.
Based on the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
func BatchNormTraining ¶
func BatchNormTraining(operand, scale, offset *Value, epsilon float32, featureAxis int) (normalized *Value, batchMean *Value, batchVariance *Value, err error)
BatchNormTraining implements batch normalization for training. See details in https://www.tensorflow.org/xla/operation_semantics#batchnormtraining.
It returns the normalized tensor, the batch mean and variance.
Based on the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
func ConvertToValidName ¶
ConvertToValidName replaces any character not in { "0"-"9", "a"-"z", "A"-"Z", "_" } with "_", making the result a valid name for values and function arguments.
func NormalizeIdentifier ¶
NormalizeIdentifier converts the name of an identifier (function name or function input parameter name, etc.) to a valid one: only letters, digits, and underscores are allowed.
Invalid characters are replaced with underscores. If the name starts with a digit, it is prefixed with an underscore.
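For illustration, a hedged one-liner (the input string is hypothetical and the expected output is inferred from the rules above, not verified against the implementation):

```go
fmt.Println(stablehlo.NormalizeIdentifier("2nd input!")) // expected: "_2nd_input_"
```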
Types ¶
type Builder ¶
type Builder struct {
// contains filtered or unexported fields
}
Builder is used to construct a StableHLO program (or "Module"). See details in New.
func New ¶
New creates a new Builder object holding a computation graph in construction.
From a builder you can create functions. For each function you create operations (ops) one by one, until you have defined the desired computation.
You have to define the "main" function for your StableHLO program: you can use Builder.Main to do so, or Builder.NewFunction("main",...), it's the same.
Once you are all set, call Builder.Build and it will return the StableHLO program (or "Module") as a []byte that can be used with PJRT.
See github.com/gomlx/gopjrt for a Go API to PJRT.
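A minimal end-to-end sketch of this flow. The import paths, the exact signature of New, and the shapes.Make helper are assumptions based on this documentation rather than verified details; errors are mostly elided for brevity:

```go
package main

import (
	"fmt"

	"github.com/gomlx/gopjrt/dtypes"
	"github.com/gomlx/stablehlo"              // assumed import path
	"github.com/gomlx/stablehlo/types/shapes" // assumed import path for the shapes package
)

func main() {
	b := stablehlo.New("add_example") // assumption: New takes the module name
	fn := b.Main()                    // the program entry point, named "main"

	// Two f32[2,3] input parameters.
	x, _ := fn.Input(shapes.Make(dtypes.Float32, 2, 3))
	y, _ := fn.Input(shapes.Make(dtypes.Float32, 2, 3))

	// Element-wise sum, returned as the single output.
	sum, _ := stablehlo.Add(x, y)
	_ = fn.Return(sum)

	// Build checks validity and renders the StableHLO module.
	program, err := b.Build()
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s\n", program) // human-readable StableHLO text, ready for PJRT
}
```

The fragments in the examples below continue this sketch, reusing b, fn, and x.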
func (*Builder) Build ¶
Build checks the validity and builds the StableHLO program.
If you want the output of an incomplete program (without the checking), use Builder.Write instead.
func (*Builder) Main ¶
Main creates the main function of the program. It is an alias to Builder.NewFunction("main", inputs...).
The main function is the entry point of the program, and it's the only function that can be called from outside the program.
Every program must have a main function.
Like with NewFunction, you can add new inputs later by calling Function.Input.
func (*Builder) Meshes ¶ added in v0.2.0
func (b *Builder) Meshes() []*shardy.DeviceMesh
Meshes returns the meshes configured with WithShardy.
func (*Builder) NewFunction ¶
NewFunction creates a new function and adds it to the program. The function outputs will be determined by the last statement in the function body.
The function name must be unique in the program.
The inputs are the values that the function will receive as arguments. The values are not added to the program, they are just used as inputs.
You can also add new inputs later by calling Function.Input.
The function body is defined by calling ops on the function object.
See Function.
func (*Builder) NewShardingSpec ¶ added in v0.2.0
func (b *Builder) NewShardingSpec() *shardy.ShardingSpec
NewShardingSpec creates a new ShardingSpec using the first mesh configured with WithShardy. It returns nil if no mesh was configured.
This is a shortcut to NewShardingSpecByMeshIx(0).
func (*Builder) NewShardingSpecByMeshIx ¶ added in v0.2.0
func (b *Builder) NewShardingSpecByMeshIx(meshIdx int) *shardy.ShardingSpec
NewShardingSpecByMeshIx creates a new ShardingSpec for the meshIdx (the order given by WithShardy).
It may return nil if meshIdx is out of range.
func (*Builder) WithNumPartitions ¶ added in v0.1.0
WithNumPartitions sets the number of partitions (for model parallelism). This is added as an attribute to the StableHLO module.
Consider using WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.
func (*Builder) WithNumReplicas ¶ added in v0.1.0
WithNumReplicas sets the number of replicas (for data parallelism). This is added as an attribute to the StableHLO module.
Consider using WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.
func (*Builder) WithShardy ¶ added in v0.2.0
func (b *Builder) WithShardy(meshes ...*shardy.DeviceMesh) *Builder
WithShardy enables distributed computation across the devices selected by the given meshes.
This is the recommended way to do distributed (across devices) computation: given inputs with sharding information, Shardy will automatically distribute the computation, without you needing to specify any of the collective operations.
Usually there is only one mesh, but one can split the devices into different meshes. The meshes are laid over the concrete devices in use.
See details of XLA Shardy in [1]
[1] https://github.com/openxla/shardy
type Computation ¶
Computation holds a rendered computation graph, that can be fed to PJRT. It is created with Builder.Build.
type DotGeneralBuilder ¶
type DotGeneralBuilder struct {
// contains filtered or unexported fields
}
DotGeneralBuilder is a builder for DotGeneral nodes. See DotGeneral for more details.
func DotGeneral ¶
func DotGeneral(lhsOp *Value, lhsContractingAxes, lhsBatchAxes []int, rhsOp *Value, rhsContractingAxes, rhsBatchAxes []int) *DotGeneralBuilder
DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications for a general vector product -- a generalized "Einsum". Each axis can be:
- Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions must match in lhs and rhs.
- Crossed (default), in which case the output is the combination (concatenation) of the dimensions.
- Contracted (contracting axes), where the corresponding values are multiplied and then reduce-summed over those dimensions.
It follows that the resulting dimension number starts with the batch dimension, then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs' non-contracting/non-batch dimension. It provides the basic means of implementing Einsum.
Because there are optional parameters, this function returns a DotGeneralBuilder that can be further configured. Call DotGeneralBuilder.Done to get the final DotGeneral node.
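For example, a plain [2,3]×[3,4] matrix multiplication contracts lhs axis 1 against rhs axis 0, with no batch axes. A sketch continuing the earlier example (shapes.Make is an assumption, errors elided):

```go
lhs, _ := fn.Input(shapes.Make(dtypes.Float32, 2, 3))
rhs, _ := fn.Input(shapes.Make(dtypes.Float32, 3, 4))
// Contract lhs axis 1 with rhs axis 0; nil batch axes.
// The result has shape f32[2,4].
product, err := stablehlo.DotGeneral(lhs, []int{1}, nil, rhs, []int{0}, nil).Done()
```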
func (*DotGeneralBuilder) Algorithm ¶
func (b *DotGeneralBuilder) Algorithm(algorithm *types.DotGeneralAlgorithm) *DotGeneralBuilder
Algorithm sets the algorithm settings to use for the dot-general operation.
The default is not to set any of these parameters.
See details in types.DotGeneralAlgorithm.
func (*DotGeneralBuilder) Done ¶
func (b *DotGeneralBuilder) Done() (*Value, error)
Done indicates the end of the DotGeneralBuilder configuration. It checks the validity of the parameters and shapes and returns the final DotGeneral node.
func (*DotGeneralBuilder) OutputDType ¶
func (b *DotGeneralBuilder) OutputDType(dtype dtypes.DType) *DotGeneralBuilder
OutputDType sets the output data type: for input types like BFloat16 one may want to increase the output precision.
func (*DotGeneralBuilder) Precision ¶
func (b *DotGeneralBuilder) Precision(lhsPrecision, rhsPrecision types.DotGeneralPrecisionType) *DotGeneralBuilder
Precision sets the precision of the dot-general operation.
Its default is described as "the fastest calculation, but the least accurate approximation to the original number."
It controls the tradeoff between speed and accuracy for computations on accelerator backends, and can be one of the types.DotGeneralPrecisionType values. At the moment, the semantics of these enum values are underspecified, but there are plans to address this in https://github.com/openxla/stablehlo/issues/755.
type Function ¶
type Function struct {
Builder *Builder
// Name of the function. It should not include the "@" prefix.
Name string
// Inputs to the function.
Inputs []*Value
// Outputs of the function.
Outputs []*Value
// Statements in the function body.
Statements []*Statement
// Parent of a closure function. It is only set if the function is a closure, and it's the function that created it.
Parent *Function
// Returned indicates if the function has a return statement, so it can no longer be changed.
Returned bool
// contains filtered or unexported fields
}
Function represents a `func.func` in StableHLO.
func (*Function) Closure ¶
Closure creates an unnamed closure function that can be used as an argument to operations like Reduce, ReduceWindow, ScatterAndUpdate, etc.
Once created, the Closure should not be changed. But it can be used multiple times within the same parent function.
The function body is defined by calling ops on the function object, as a usual Function object.
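For instance, a scalar-add closure gives Reduce a sum reduction. A sketch continuing the earlier example, assuming shapes.Make with no dimensions yields a scalar shape (errors elided):

```go
// Build a closure computing (lhs, rhs) -> lhs + rhs on scalars.
reduceFn := fn.Closure()
lhs, _ := reduceFn.Input(shapes.Make(dtypes.Float32)) // scalar f32
rhs, _ := reduceFn.Input(shapes.Make(dtypes.Float32))
sum, _ := stablehlo.Add(lhs, rhs)
_ = reduceFn.Return(sum)

// Sum-reduce x over axis 0, starting from 0.
zero, _ := fn.ConstantFromScalar(float32(0))
reduced, err := stablehlo.Reduce(x, zero, reduceFn, 0)
```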
func (*Function) ConstantFromFlatAndDimensions ¶
ConstantFromFlatAndDimensions creates a new constant statement from a flat slice with the raw values and the dimensions of the shape.
func (*Function) ConstantFromScalar ¶
ConstantFromScalar creates a new constant statement and returns the resulting value.
func (*Function) Input ¶
Input creates a new input parameter for a function.
If creating multiple inputs (one at a time), the order matters, since during execution of a compiled function, the input parameters must be given in the same order they were created.
These add to the inputs already created during the function creation.
It picks a default unique name for the input parameter; you can also provide a name with NamedInput.
func (*Function) InputWithAttributes ¶ added in v0.2.0
func (fn *Function) InputWithAttributes(shape shapes.Shape, attributes map[string]any) (*Value, error)
InputWithAttributes creates a new input with the given attributes.
func (*Function) InputWithSharding ¶ added in v0.2.0
func (fn *Function) InputWithSharding(shape shapes.Shape, shardingSpec *shardy.ShardingSpec) (*Value, error)
InputWithSharding creates a new input with the given sharding specification.
func (*Function) InputWithShardingAndAttributes ¶ added in v0.2.0
func (fn *Function) InputWithShardingAndAttributes(shape shapes.Shape, shardingSpec *shardy.ShardingSpec, attributes map[string]any) (*Value, error)
InputWithShardingAndAttributes creates a new input with the given sharding specification and attributes.
func (*Function) Iota ¶
Iota creates a constant of the given shape with increasing numbers (starting from 0) on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0) returns [[0 0][1 1]].
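In code, continuing the earlier sketch (shapes.Make is an assumption):

```go
// Int32 [2,2] counting along axis 1: [[0 1] [0 1]].
iotaV, err := fn.Iota(shapes.Make(dtypes.Int32, 2, 2), 1)
```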
func (*Function) NamedInput ¶
NamedInput creates a new input parameter for a function with the given name -- it must be a unique input name.
The name is passed through ConvertToValidName, which converts any character that is not a digit, ASCII letter, or underscore to an underscore.
Names with the format "%d" and "arg%d" are reserved for the default input parameters.
Names are used in the StableHLO code and may be helpful for debugging, but otherwise have no impact.
func (*Function) NamedInputWithAttributes ¶ added in v0.2.0
func (fn *Function) NamedInputWithAttributes(name string, shape shapes.Shape, attributes map[string]any) (*Value, error)
NamedInputWithAttributes creates a new input parameter for a function with the given name and attributes.
func (*Function) NamedInputWithSharding ¶ added in v0.2.0
func (fn *Function) NamedInputWithSharding(name string, shape shapes.Shape, shardingSpec *shardy.ShardingSpec) (*Value, error)
NamedInputWithSharding creates a new input parameter for a function with the given name -- it must be a unique input name -- and sharding specification for distributed computation.
func (*Function) NamedInputWithShardingAndAttributes ¶ added in v0.2.0
func (fn *Function) NamedInputWithShardingAndAttributes(name string, shape shapes.Shape, shardingSpec *shardy.ShardingSpec, attributes map[string]any) (*Value, error)
NamedInputWithShardingAndAttributes creates a new input parameter for a function with the given name -- it must be a unique input name -- plus a sharding specification for distributed computation and attributes.
The shardingSpec can be nil: the default is a replicated input across all devices.
The name is passed through ConvertToValidName, which converts any character that is not a digit, ASCII letter, or underscore to an underscore.
Names with the format "%d" and "arg%d" are reserved for the default input parameters.
Names are used in the StableHLO code and may be helpful for debugging, but otherwise have no impact.
func (*Function) Return ¶
Return adds a return statement to the function with the given return values. There must be at least one return value.
There can be only one return statement from a Function, and it must be the last operation of a function.
If you are doing distributed computation, you can use ReturnWithShardingAndAttributes to specify the sharding requirements for each of the return values.
func (*Function) ReturnWithAttributes ¶ added in v0.2.0
ReturnWithAttributes adds a return statement to the function with the given return values and attributes.
func (*Function) ReturnWithShardingAndAttributes ¶ added in v0.2.0
func (fn *Function) ReturnWithShardingAndAttributes(values []*Value, shardingSpecs []*shardy.ShardingSpec, attributes []map[string]any) error
ReturnWithShardingAndAttributes is a convenience function to call ReturnWithAttributes with the given sharding specifications.
The shardingSpecs slice of ShardingSpecs must have the same length as the values slice. Each ShardingSpec can be nil, in which case the default sharding is replicated across all devices. If shardingSpecs is nil, this behaves just like ReturnWithAttributes.
The attributes slice of maps can be set to nil if there are no attributes.
type Statement ¶
type Statement struct {
Builder *Builder
Function *Function
// OpType is the type of the operation.
OpType optypes.OpType
// Inputs to the operation.
Inputs []*Value
// Attributes of the operation.
Attributes map[string]any
// FunctionParameters for statements with operations like Reduce, ReduceWindow, ScatterAndUpdate, etc.
FunctionParameters []*Function
FunctionParametersNames []string
// Outputs of the operation. It may be nil for operations like func.return.
Outputs []*Value
}
Statement represents a single operation line in StableHLO.
func (*Statement) AddFunctionParameter ¶
type Value ¶
Value represents a value in a StableHLO program, like `%0` or `%arg0`. These values can be inputs, outputs or intermediary values of functions.
It is always associated with a function (where it's being used) and must be uniquely identified by a string with digits '0'-'9', 'A'-'Z', 'a'-'z' or '_'.
For inlined functions (for instance, the one passed to a Reduce operation), the names cannot clash with the parent function name (!?). But the names can be reused in different inline functions.
It also carries its shape information.
func AllGather ¶ added in v0.1.0
func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config ...*types.CollectiveConfig) (*Value, error)
AllGather concatenates the operand from each replica along a specified dimension.
- operand: The tensor from the *local* replica to be gathered.
- replicaGroups: A 2D array defining the communicating device groups.
- allGatherDim: The dimension along which to concatenate the operands.
- config: Optional configuration of the channels to be used.
Consider using Builder.WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.
func AllReduce ¶ added in v0.1.0
func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, config ...*types.CollectiveConfig) ([]*Value, error)
AllReduce performs a distributed reduce operation across replicas. It is a distributed version of Reduce.
- operands: The tensors from the *local* replica to be reduced.
- replicaGroups: A 2D array defining the communicating device groups, e.g., `[[0, 1, 2, 3]]`. For standard data parallelism, this is typically a single group with all the replica numbers -- notice these are replica numbers, not device numbers (there is an indirection) -- except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device numbers.
- computation: A closure function that defines the reduction operation (e.g., SUM). It must take two scalar inputs for each operand's dtype and return one scalar output of the same dtype.
- config: Optional configuration of the channels to be used. This is not needed for SPMD programs.
Consider using Builder.WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.
func AllToAll ¶ added in v0.1.0
func AllToAll(operand *Value, replicaGroups [][]int, splitDimension, concatDimension, splitCount int, config ...*types.CollectiveConfig) (*Value, error)
AllToAll splits the operand along a specified dimension and scatters the chunks to all replicas, where they are concatenated back together.
- operand: The tensor from the *local* replica.
- replicaGroups: A 2D array defining the communicating device groups.
- splitDimension: The dimension along which to split the operand.
- concatDimension: The dimension along which to concatenate the received chunks.
- splitCount: The number of chunks to split the operand into. This must match the size of the replica groups.
- config: Optional configuration of the channels to be used.
Consider using Builder.WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.
func BatchNormInference ¶
func BatchNormInference(operand, scale, offset, mean, variance *Value, epsilon float32, featureAxis int) (*Value, error)
BatchNormInference implements batch normalization for inference. See details in https://www.tensorflow.org/xla/operation_semantics#batchnorminference.
Based on the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.
func BitcastConvert ¶
BitcastConvert performs an elementwise bit-cast operation from a dtype to another dtype.
BitcastConvert doesn't "convert"; rather, it just reinterprets the bits from operand.DType() to the targetDtype.
If operand.DType() and targetDtype use the same number of bytes (targetDtype.Size() == operand.DType().Size()), the dimensions are not changed; only the dtype changes.
If targetDtype.Size() > operand.DType().Size(), the operand's last axis must have a dimension of targetDtype.Size() / operand.DType().Size(), and the returned shape drops that last axis.
If targetDtype.Size() < operand.DType().Size(), the returned shape gains an extra trailing axis with dimension operand.DType().Size() / targetDtype.Size().
E.g: Bitcast([1]uint32{0xdeadbeef}, dtypes.UInt16) -> [1][2]uint16{{0xbeef, 0xdead}} // Little-endian encoding.
func BroadcastInDim ¶
BroadcastInDim broadcasts dimensions from the operand to the target shape. It can also transpose axes and add new ones.
The axesMapping should have one value per operand axes. It maps the axes from the operand to the corresponding value on the target shape.
func Clamp ¶
Clamp returns the minimum(maximum(x, min), max).
The values max and min can either be a scalar or have the same shape as x.
Clamp is not defined for booleans or complex numbers (the semantics would not be clear).
Note: the order of the arguments in StableHLO is different from most ML libraries.
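A sketch of clamping to [0, 1], continuing the earlier example (errors elided):

```go
// Note StableHLO's argument order: min first, then x, then max.
lo, _ := fn.ConstantFromScalar(float32(0))
hi, _ := fn.ConstantFromScalar(float32(1))
clamped, err := stablehlo.Clamp(lo, x, hi)
```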
func CollectiveBroadcast ¶ added in v0.1.0
func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types.CollectiveConfig) (*Value, error)
CollectiveBroadcast broadcasts the value from the first replica (in each group) to all others. The returned shape is the same as the source. Devices not included in any replica group will return zeros as their output (with the same shape as the input).
- operand: The tensor to be broadcasted. In an SPMD setup, this op will be called on all replicas, but only the operand from the source device (the first device in the replica_group) will be used.
- replicaGroups: A 2D array defining the communicating device groups, e.g., `[[0, 1, 2, 3]]`. For standard data parallelism, this is typically a single group with all the replica numbers -- notice these are replica numbers, not device numbers (there is an indirection) -- except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device numbers.
- config: Optional configuration of the channels to be used. This shouldn't be used for SPMD programs.
Consider using Builder.WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.
func CollectivePermute ¶ added in v0.1.0
func CollectivePermute(operand *Value, sourceTargetPairs [][2]int, config ...*types.CollectiveConfig) (*Value, error)
CollectivePermute sends the operand from a source replica to a target replica.
- operand: The tensor from the *local* replica.
- sourceTargetPairs: A 2D array where each inner array is a `[source, target]` pair of replica IDs.
- config: Optional configuration of the channels to be used.
Consider using Builder.WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.
func Compare ¶
func Compare(lhs, rhs *Value, direction types.ComparisonDirection, compareType types.ComparisonType) (*Value, error)
Compare implements the corresponding standard binary operation.
For boolean data types (dtypes.Bool) use the types.CompareUnsigned type.
func Complex ¶
Complex builds a complex value by combining the real and imaginary parts element-wise.
func Concatenate ¶
Concatenate operands on the given axis.
All operands must have matching dimensions on every axis except the one being concatenated. It doesn't work with scalars -- use ExpandAxes. If there is only one operand, it is returned and this is a no-op.
func Convert ¶
Convert x to the given dtype.
For boolean to numeric conversions, false becomes 0 and true 1.
For complex to non-complex conversions, the imaginary part is discarded (or set to 0).
Currently, it doesn't work for quantized to/from regular tensors. Use UniformQuantize and UniformDequantize for that.
func Convolution ¶
func Convolution(input, kernel *Value, strides []int, paddings [][2]int, inputDilations, kernelDilations []int, inputBatchAxis, inputChannelsAxis int, inputSpatialAxes []int, kernelInputChannelsAxis, kernelOutputChannelsAxis int, kernelSpatialAxes []int, outputBatchAxis, outputChannelsAxis int, outputSpatialAxes []int, channelGroupCount, batchGroupCount int, inputPrecision, kernelPrecision types.DotGeneralPrecisionType) (*Value, error)
Convolution performs a convolution supporting strides, padding, dilations, feature grouping, and batch grouping.
See description in https://openxla.org/stablehlo/spec#convolution
The parameters strides, paddings, inputDilations, and kernelDilations can be set to nil, and the default (zeros for paddings and ones for the others) will be used.
Note: since the spec mentions that window_reversal will be removed, we didn't include it in the API. If you need it, we can create an alternative API for Convolution with it.
func CountLeadingZeros ¶
CountLeadingZeros implements the corresponding standard unary operation.
func Dot ¶ added in v0.2.0
Dot computes the dot product of lhs and rhs, with the usual semantics depending on the operand ranks: vector·vector, matrix·vector, or matrix·matrix products.
For generalized contractions -- batch axes or arbitrary contracting axes -- see DotGeneral.
func DynamicSlice ¶
DynamicSlice extracts a slice from the operand at the startIndices position and the given sliceSizes.
- operand: tensor from where to take the slice.
- startIndices: scalar tensors, one per axis of operand: len(startIndices) == operand.Rank().
- sliceSizes: static values, fixed so that the shape of the output remains static.
The startIndices are adjusted as follows:
adjustedStartIndices[i] = clamp(0, StartIndices[i], operand.Dimensions[i] - sliceSizes[i])
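A sketch, assuming a hypothetical f32[4,5] value named operand (errors elided):

```go
// Extract a [2,3] window starting at (row=1, col=0).
// Start indices are scalar Values, one per operand axis.
row, _ := fn.ConstantFromScalar(int32(1)) // start index for axis 0
col, _ := fn.ConstantFromScalar(int32(0)) // start index for axis 1
window, err := stablehlo.DynamicSlice(operand, []*stablehlo.Value{row, col}, []int{2, 3})
```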
func DynamicUpdateSlice ¶
DynamicUpdateSlice updates the operand with the values given in update, at the position given by startIndices.
- operand: the original value to be updated.
- update: values to "paste" on top of operand, at position startIndices.
- startIndices: scalar tensors, one per axis of operand: len(startIndices) == operand.Rank().
It returns a value with the same shape as the operand, with the values updated.
The startIndices are adjusted as follows:
adjustedStartIndices[i] = clamp(0, StartIndices[i], operand.Dimensions[i] - update.Dimensions[i])
func Exponential ¶
Exponential implements the corresponding standard unary operation.
func ExponentialMinusOne ¶
ExponentialMinusOne implements the corresponding standard unary operation.
func FFT ¶
FFT calls the XLA FFT operation, which implements {Forward, Inverse} x {Complex, Real} versions. See documentation in https://openxla.org/stablehlo/spec#fft, but more details in XLA page here: https://openxla.org/xla/operation_semantics#fft.
If fftLength is not given, one is picked for you: the last axis dimension for types.FFTForward, types.FFTInverse, and types.FFTForwardReal, and (lastDim-1)*2 for types.FFTInverseReal.
The underlying Gopjrt implementation for CPU FFT is backed by Eigen's TensorFFT, and for GPU FFT it uses cuFFT.
func Gather ¶
func Gather(operand, startIndices *Value, indexVectorAxis int, offsetOutputAxes, collapsedSliceAxes, operandBatchingAxes, startIndicesBatchingAxes, startIndexMap, sliceSizes []int, indicesAreSorted bool) (*Value, error)
Gather is a powerful but cumbersome Gather operation. Full details in https://openxla.org/stablehlo/spec#gather.
The output of Gather has the same DType of the operand, from where we are pulling the data.
Its output shape will be composed of 2 parts:
- Batch axes: they come from operandBatchingAxes/startIndicesBatchingAxes (they correspond to each other) and from the other axes of startIndices, except the "indexVectorAxis" (usually the last) that is used as the indices into the operand. (*)
- "Offset axes": these are axes that come from the operand, the sizes given by sliceSizes. Notice that if sliceSizes for an axis is 1, and that axis is present in the collapsedSliceAxes list, this axis gets omitted in the output.
So in general output.Rank() = startIndices.Rank() - 1 + len(offsetAxes).
(*) One exception is if indexVectorAxis == startIndices.Rank(), in which case we assume there is an extra virtual axis in startIndices of size 1, in which case output.Rank() = startIndices.Rank() + len(offsetAxes).
Arguments:
- operand: the values from where we are gathering. The output DType will follow the operand one.
- startIndices: are the indices we want to gather. The axis pointed by indexVector lists the indices of the slice to be gathered in the operand array (their values are mapped to the axis in the operand according to startIndexMap). All other axes are "batch dimensions" and they will have equivalent axes (same dimensions) in the output.
- indexVectorAxis: which of the axis in startIndices is collected and used as the start index for slices to be gathered in the operand. It is typically the last axis of startIndices, so startIndices.Shape.Rank()-1. There is a special case where indexVectorAxis == startIndices.Rank() in which case we assume there is an extra virtual axis in startIndices of size 1, in which case output.Rank() = startIndices.Rank() + len(offsetAxes).
- offsetOutputAxes: _output_ axes (not the operand's) that will hold the "offset slices", slices that are not collapsed. It points to which position (axis) in the output these slices should show up. The len(offsetOutputAxes) must match the dimension of indexVectorAxis (== startIndices.Dimensions[indexVectorAxis]). Notice every axis in the operand will either become an "offset axis" in the output, or optionally be collapsed ("squeezed") from the output, if included in collapsedSliceAxes. The axes in the output (given in offsetOutputAxes) map sequentially to the axes in the operand that are not present in collapsedSliceAxes. One must have operand.Rank() == len(collapsedSliceAxes) + len(offsetOutputAxes) + len(operandBatchingAxes).
- collapsedSliceAxes: _operand_ axes (for which sliceSizes are 1) not to be included in the output. One must have sliceSizes[collapsedSliceAxes[i]] == 1 for all i.
- operandBatchingAxes: operand's batching axes that have corresponding batching axes in the startIndices, and that will also be included in the output. One must have sliceSizes[operandBatchingAxes[i]] == 1 for all i. Also, one must have Rank(operand) == len(operandBatchingAxes) + len(collapsedSliceAxes) + len(offsetOutputAxes).
- startIndicesBatchingAxes: startIndices' batching axes that have corresponding batching axes in the operand, and that will also be included in the output.
- startIndexMap: this maps which value in startIndices is used for which axis in the operand, selecting the slice to be gathered. Notice len(startIndexMap) must match startIndices.Dimensions[indexVectorAxis]. Also, len(startIndexMap) == len(offsetOutputAxes) -- offsetOutputAxes maps the same axes in the output. E.g.: if startIndices.shape=(2, 3), indexVectorAxis=1, operand.rank=4 and startIndexMap=[]int{0, 1, 2}, this means each row of the startIndices will point to the first 3 axes (0, 1 and 2) in the operand. For those axes in the operand not explicitly set (so if len(startIndexMap) < operand.Rank()) and not part of operandBatchingAxes, the corresponding axis start index is considered to be 0, and one sets the sliceSizes to take the slice one wants (typically the full slice).
- sliceSizes: a size for each of the operand's axes, so len(sliceSizes) == operand.Rank(). Once the start index from where to gather is resolved, this defines how much data in each axis to gather. Constraints: sliceSizes[collapsedSliceAxes[i]] == 1, and sliceSizes[operandBatchingAxes[j]] == 1, for all i, j.
- indicesAreSorted: can be set to true if it's guaranteed that startIndices are sorted (in ascending order, after scattering its values according to start_index_map) by the user. This allows for some optimizations in some platforms.
func IsFinite ¶
IsFinite tests whether each element of the operand is finite, i.e., neither positive nor negative infinity, and not NaN. It returns the same shape as the input, but with boolean values where each element is true if and only if the corresponding input element is finite.
func LogPlusOne ¶
LogPlusOne implements the corresponding standard unary operation.
func MultiReduce ¶
func MultiReduce(inputs, initialValues []*Value, reductionFn *Function, axes ...int) ([]*Value, error)
MultiReduce reduces the input along the given axes.
Each resulting value i is initialized with initialValues[i] (e.g.: for a sum it's 0, for a product it's 1), and then each value is combined with it using the reduction function.
The reduction function must be created with Function.Closure. If there are N inputs and initialValues, the reduction function should have the signature (lhs_1, ..., lhs_N, rhs_1, ..., rhs_N) and output (out_1, ..., out_N), where lhs_i and rhs_i are scalars taken from the inputs.
It returns N results for each aggregated value.
See Reduce for a version that accepts a single input.
TODO: promotion of types doesn't seem to be working according to the spec in https://openxla.org/stablehlo/spec#reduce.
func MultiReduceWindow ¶
func MultiReduceWindow(inputs, initialValues []*Value, reductionFn *Function, windowDimensions, strides, inputDilations, windowDilations []int, paddings [][2]int) ([]*Value, error)
MultiReduceWindow reduces the inputs using arbitrary windows around each element.
Each resulting element for inputs[i] is initialized with initialValues[i] (e.g.: for a sum it's 0, for a product it's 1), and then each value is combined with the window around the element using the reduction function.
The reduction function must be created with Function.Closure. If there are N inputs and initialValues, the reduction function should have the signature (lhs_1, ..., lhs_N, rhs_1, ..., rhs_N) and output (out_1, ..., out_N), where lhs_i and rhs_i are scalars.
It returns N results for each aggregated value.
See ReduceWindow for a version that accepts a single input.
If strides is not set, it defaults to the value of windowDimensions -- the stride matches the window size.
TODO: promotion of types doesn't seem to be working according to the spec.
func MultiScatter ¶
func MultiScatter(inputs []*Value, scatterIndices *Value, updates []*Value, updateWindowAxes, insertedWindowAxes []int, inputBatchingAxes, scatterIndicesBatchingAxes []int, indexedInputAxes []int, indexVectorAxis int, indicesAreSorted, uniqueIndices bool, updateComputationFn *Function) ([]*Value, error)
MultiScatter is like Scatter, but takes N inputs and updates with a single set of indices, and performs the Scatter on all of them at the same time.
func Pad ¶
Pad x at start, end or interior (interleaved) at arbitrary axes.
It adds padding values around and in-between the elements of x. For each axis:
- paddingStart elements are inserted before the tensor. This value can be negative, in which case elements are removed from the start of the axis.
- paddingEnd elements are appended after the tensor. This value can be negative, in which case elements are removed from the end of the axis.
- paddingInterior elements are inserted between consecutive elements of the tensor. So setting paddingInterior[i]=2 for axis "i" means 2 elements will be inserted between every adjacent pair of elements. paddingInterior can not be negative.
If any of the padding parameters is not given, it is set to 0 for all axes.
The fill value must be a scalar with the same DType as x and determines what value will be used for the padding.
The output shape is defined by:
For each axis i in x: output.Dimensions[i] = paddingStart[i] + x.Dimensions[i] + max((x.Dimensions[i]-1), 0)*paddingInterior[i] + paddingEnd[i]
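A sketch applying that formula, assuming a hypothetical f32[5] value named v (errors elided):

```go
// Pad v with zeros: 1 at the start, 1 at the end, and 1 between neighbors.
// Output dimension = 1 + 5 + (5-1)*1 + 1 = 11.
zero, _ := fn.ConstantFromScalar(float32(0))
padded, err := stablehlo.Pad(v, zero, []int{1}, []int{1}, []int{1})
```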
func RNGBitGenerator ¶ added in v0.2.0
func RNGBitGenerator(state *Value, shape shapes.Shape, algorithm types.RNGBitGeneratorAlgorithm) (newState, values *Value, err error)
RNGBitGenerator generates the given shape filled with random bits. It takes the current random number generator (RNG) state, see RngState or RngStateFromSeed.
It returns the new state of the RNG and the generated values (with random bits) with the given shape.
The state shape depends on the algorithm:
- types.RngDefault: PJRT implementation defined.
- types.RngThreeFry: 2xUint64.
- types.RngPhilox: 2xUint64 or 3xUint64.
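A hedged sketch of generating 16 random uint32 values with the ThreeFry algorithm, continuing the earlier example (the types and shapes import paths are assumptions; errors elided):

```go
// A ThreeFry state is ui64[2]; here it enters the program as an input.
state, _ := fn.Input(shapes.Make(dtypes.Uint64, 2))
newState, bits, err := stablehlo.RNGBitGenerator(
	state, shapes.Make(dtypes.Uint32, 16), types.RngThreeFry)
```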
func Reduce ¶
Reduce reduces the input along the given axes.
Each resulting value is initialized with initialValue (e.g.: for a sum it's 0, for a product it's 1), and then each value is combined with it using the reduction function.
The reduction function must be created with Function.Closure; it should take scalar values as input, and it must be associative and commutative.
The initialValue and x must have the same DType. This initial dtype must be promotable to the dtype accepted by the reduction function. The result dtype is the same as the output of the reduction function. So one could reduce-sum a 4-bit quantized tensor directly into a Float32.
See MultiReduce for a version that accepts multiple inputs and outputs.
func ReduceWindow ¶
func ReduceWindow(input, initialValue *Value, reductionFn *Function, windowDimensions, strides, inputDilations, windowDilations []int, padding [][2]int) (*Value, error)
ReduceWindow reduces the inputs using arbitrary windows around each element.
Each resulting element for the input is initialized with initialValue (e.g.: for a sum it's 0, for a product it's 1), and then each value is combined with the window around the element using the reduction function.
The reduction function must be created with Function.Closure and should have the signature `(lhs, rhs) -> out`, where lhs, rhs and out are scalars.
If strides is not set, it defaults to the value of windowDimensions -- the stride matches the window size.
See MultiReduceWindow for a version that supports reducing multiple inputs at once.
TODO: promotion of types doesn't seem to be working according to the spec.
func Reshape ¶
Reshape the operand to the given shape. The total size of the new shape must match the original shape.
This has no effect on the data, no transposition is performed.
func RoundNearestAfz ¶
RoundNearestAfz implements the corresponding standard unary operation.
func RoundNearestEven ¶
RoundNearestEven implements the corresponding standard unary operation.
func Scatter ¶
func Scatter(input, scatterIndices, updates *Value, updateWindowAxes, insertedWindowAxes []int, inputBatchingAxes, scatterIndicesBatchingAxes []int, indexedInputAxes []int, indexVectorAxis int, indicesAreSorted, uniqueIndices bool, updateComputationFn *Function) (*Value, error)
Scatter returns the input updated with the values of updates at the locations pointed to by scatterIndices. It allows axes to be used in powerful ways, but it's complex to get right. Full details in https://openxla.org/stablehlo/spec#scatter.
The output of Scatter has the same shape and DType of the input.
Batching: while batching axes are only defined for the input and scatterIndices, the batching axes for the updates are inferred from the scatterIndices.
Arguments:
- input: value to be updated in a scattered fashion.
- scatterIndices: indices of the values to be scattered.
- updates: updated values to be scattered at scatterIndices.
- updateWindowAxes: these axes provide the shape of the update window.
- insertedWindowAxes: in the resulting tensor, each axis is either a batch axis, part of the update window (not specified, taken sequentially), or an inserted window axis defined by this argument.
- inputBatchingAxes: axes that are batched with the input.
- scatterIndicesBatchingAxes: axes that are batched with the scatterIndices.
- indexedInputAxes: axes that are indexed by the scatterIndices at axis indexVectorAxis (aka. "scatter_dims_to_operand_dims").
- indexVectorAxis: the axis in scatterIndices that forms a vector of indices into the input, indicating where to scatter. This vector, of length scatterIndices.Dimensions[indexVectorAxis], defines the index value in the input on the axes given by indexedInputAxes. E.g.: with indexedInputAxes = [0, 1], indexVectorAxis = 0, and scatterIndices = [[0, 1, 2], [3, 4, 5]], the values from updates[0] are scattered at input[0, 3], updates[1] at input[1, 4], and so on. The shape of scatterIndices is then [2, ...], the first axis being the index vector.
- indicesAreSorted, uniqueIndices: they can be set to true if it's guaranteed that scatterIndices are sorted (in ascending order) and/or unique (no duplicates). This allows for some optimizations in some platforms.
- updateComputationFn: the closure that element-wise combines the current input value and the update value, computing the result. It also defines the data type of the outputs: if the updateComputationFn inputs and outputs don't match the corresponding DType of the inputs and updates, the values from inputs and updates must be "promotable" to the DType of the updateComputationFn. Notice it may be called multiple times for some elements if the indices are not unique or the updates' windows overlap.
func Select ¶
Select takes element-wise values from onTrue or onFalse depending on the value of the pred (must be boolean).
The pred must be boolean and can be a scalar or have the same shape as onTrue and onFalse. onTrue and onFalse must have the same shape and dtype.
func SelectAndScatter ¶
func SelectAndScatter(input, scatterSource, initialValue *Value, selectFn, scatterFn *Function, windowDimensions, strides []int, paddings [][2]int) (*Value, error)
SelectAndScatter performs a ReduceWindow on the input, selecting one value per window (using the selectFn to choose the value), and then aggregating this value into the output (at the same index as the input).
The returned result has the same shape as the input and is initialized with initialValue.
func ShiftRightArithmetic ¶
ShiftRightArithmetic implements the corresponding standard binary operation.
func ShiftRightLogical ¶
ShiftRightLogical implements the corresponding standard binary operation.
func Slice ¶
Slice extracts a subarray from the input array. The subarray is of the same rank as the input and contains the values inside a bounding box within the input array where the dimensions and indices of the bounding box are given as arguments to the slice operation. The strides set the input stride of the slice in each axis and must be >= 1. It is optional, and if missing, it is assumed to be 1 for every dimension. Examples:
Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={4}, strides=nil) -> {2, 3}
Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={5}, strides={2}) -> {2, 4}
func Transpose ¶
Transpose axes of x.
There should be one value in permutation for each axis in x (len(permutation) == rank(x)).
The output will have: output.Shape.Dimensions[i] = x.Shape.Dimensions[permutation[i]].
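For example, continuing the earlier sketch where x is f32[2,3] (errors elided):

```go
// Swap the two axes of x; the result has shape f32[3,2].
xt, err := stablehlo.Transpose(x, 1, 0)
```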
Source Files ¶
Directories ¶
| Path | Synopsis |
|---|---|
| internal | |
| cmd/ops_generator | command |
| optypes | Package optypes defines OpType and lists the supported operations. |
| utils | Package utils holds small utility types and functions used internally in stablehlo. |
| shapeinference | Package shapeinference calculates the shape resulting from operations and validates its inputs. |
| shapes | Package shapes defines Shape and DType and associated tools. |
| shardy | Package shardy provides the types needed to define a distributed computation topology. |