stablehlo

package module
v0.2.0
Published: Dec 4, 2025 License: Apache-2.0 Imports: 19 Imported by: 1

README

XLA's StableHLO Builder API for Go

[!Note] Discussion happens in the Slack channel #gomlx (you can join the Slack server here).

StableHLO is an operation set for high-level operations (HLO) in machine learning (ML) models.

It's the portability layer between ML frameworks (targeted at GoMLX, but usable by others) and ML compilers. By coupling with XLA's PJRT (*) API for executing StableHLO programs, it allows easy support for different hardware vendors: many different GPUs and TPUs are supported.

(*) PJRT, which stands for Pluggable JIT Runtime, is an API in the context of XLA (Accelerated Linear Algebra) that provides a unified, cross-platform interface for interacting with different hardware accelerators. StableHLO is the device-independent language used to specify the computation; PJRT also includes APIs to handle buffer (data) management and, optionally, distributed execution.

See:

  • StableHLO specification
  • GoMLX: a Go ML framework that supports an XLA (StableHLO+PJRT) backend to efficiently run (or train) ML programs.
  • GoPJRT: a Go wrapper for PJRT C API, capable of executing StableHLO programs, for a lower level API.

Examples

The tests in tests/gopjrt/gopjrt_test.go should serve as simple examples of each operation.

Notice that stablehlo is a low-level API, usually used to build higher-level frameworks (an ML framework like GoMLX, maybe an image manipulation library that uses accelerators like GPUs, some scientific library, etc.), so it's deliberately verbose and requires boilerplate (error handling) everywhere. It sacrifices ergonomics for performance, consistency and stability.
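
To give a feeling for the verbosity, here is a minimal sketch of a program that adds two scalars. The import paths for the shapes and dtypes packages, and the shapes.Make helper, are assumptions; adjust them to your setup:

	package main

	import (
		"fmt"
		"log"

		"github.com/gomlx/stablehlo"              // assumed module path
		"github.com/gomlx/stablehlo/types/shapes" // assumed location of the shapes sub-package
		"github.com/gomlx/gopjrt/dtypes"          // assumed location of dtypes
	)

	func main() {
		b := stablehlo.New("add_scalars")
		fn := b.Main()

		// Every call returns an error that must be handled -- hence the boilerplate.
		x, err := fn.Input(shapes.Make(dtypes.Float32))
		if err != nil {
			log.Fatal(err)
		}
		y, err := fn.Input(shapes.Make(dtypes.Float32))
		if err != nil {
			log.Fatal(err)
		}
		sum, err := stablehlo.Add(x, y)
		if err != nil {
			log.Fatal(err)
		}
		if err = fn.Return(sum); err != nil {
			log.Fatal(err)
		}

		program, err := b.Build() // StableHLO text, ready to be compiled by PJRT.
		if err != nil {
			log.Fatal(err)
		}
		fmt.Printf("%s\n", program)
	}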

See another example of stablehlo and GoPJRT (to execute the generated StableHLO program) in the Mandelbrot mandelbrot.ipynb notebook. It includes some sample StableHLO code, if you are curious.

Status of Operations

Most operations are already implemented. See the list of supported operations (the ones not implemented are at the bottom of the list).

If you need a specific operation, please open an issue.

See also the CHANGELOG.

Dynamic Shapes Support: unbounded dynamism using shape polymorphism only!

In the first version we aim to support only unbounded dynamism using shape polymorphism: axis dimensions are left undefined and have no bounds, and PJRT can dynamically re-instantiate and re-compile the program for a new shape (or re-use a cached compilation).

Other types of dynamism:

  • Unranked dynamism: the rank is unknown at compile time. Not supported.
  • Data-dependent dynamism: for data-dependent dynamic ops, for instance, a function that returns the indices of all non-zero elements. There is little support for this, so we don't support it yet.

StableHLO replaces GoPJRT's XlaBuilder

With the following advantages:

  • XlaBuilder has become a second-class citizen, so to speak, within OpenXLA, and things are moving towards the "MLIR builder" (MLIR is the Multi-Level Intermediate Representation, of which StableHLO is a specialization/extension). So we will eventually need a newer "builder" for Gopjrt.
  • Since PJRT takes StableHLO in plain text format, we can write this entirely in Go, not requiring any extra C/C++ library build.
    • PJRT itself is a C library, but with a relatively small API surface, and for which there are prebuilt distributions available (for Jax). So we can get away without having to manage Bazel issues.
    • The goal is to eventually not require a C compiler to compile gopjrt, and instead use ebitengine/purego to dynamically load PJRT.
    • There are PJRT plugins for different platforms. Not having to compile XlaBuilder for each of them makes supporting them more plausible.

The disadvantages:

  • XlaBuilder provided "shape inference": if I say Add(a, b), the XlaBuilder would tell how to broadcast a and b and what the resulting shape is. When we build the StableHLO program we have to re-implement this shape inference, not only for the GoPJRT users, but also because the StableHLO language requires the input and output shapes to be specified in every statement.
    • This is not a disadvantage for the user of this library, since stablehlo does that for you, but it's more work for the library maintainers.
  • This means more maintenance: any updates in the language specification or new ops need to have their shape inference updated accordingly.

The shapeinference sub-package

The shape inference code is also used by the GoMLX SimpleGo engine (github.com/gomlx/gomlx/backends/simplego), but we didn't want to create a dependency in either direction: users of Gopjrt may not be interested in GoMLX, and users of GoMLX that don't use the XLA backend wouldn't want a dependency on Gopjrt.

Also, some operations have slightly different nuances in each project.

Documentation

Overview

Package stablehlo helps build a StableHLO program (text format) that is then JIT-compiled and executed by PJRT (github.com/gomlx/gopjrt/pjrt).

Among its features:

  • Translates an API to rendered (human-readable) StableHLO text.
  • Shape inference: it calculates the output shapes for operations.
  • Written purely in Go, no C/C++ external dependencies.

It was written as a replacement for `gopjrt/xlabuilder` and attempts to keep a similar or identical interface.

See the StableHLO documentation and specification at https://openxla.org/stablehlo/spec

Index

Constants

const IndentationStep = "  "

const MainFunctionName = "main"

Variables

This section is empty.

Functions

func BatchNormGradient

func BatchNormGradient(operand, scale, mean, variance, gradOutput *Value, epsilon float32, featureAxis int) (gradOperand *Value, gradScale *Value, gradOffset *Value, err error)

BatchNormGradient calculates the batch normalization gradients with respect to the input, scale, and offset. https://openxla.org/xla/operation_semantics#batchnormgrad

The gradOutput is the adjoint gradient (the "V" in "VJP"), that is, the gradient with respect to the output of the batch normalization.

Based on the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.

func BatchNormTraining

func BatchNormTraining(operand, scale, offset *Value, epsilon float32, featureAxis int) (normalized *Value, batchMean *Value, batchVariance *Value, err error)

BatchNormTraining implements batch normalization for training. See details in https://www.tensorflow.org/xla/operation_semantics#batchnormtraining.

It returns the normalized tensor, the batch mean and variance.

Based on the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.

func ConvertToValidName

func ConvertToValidName(name string) string

ConvertToValidName replaces any character not in { "0"-"9", "a"-"z", "A"-"Z", "_" } with "_", making the result a valid name for values and function arguments.

func NormalizeIdentifier

func NormalizeIdentifier(name string) string

NormalizeIdentifier converts the name of an identifier (function name or function input parameter name, etc.) to a valid one: only letters, digits, and underscores are allowed.

Invalid characters are replaced with underscores. If the name starts with a digit, it is prefixed with an underscore.
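
A quick sketch of both rules:

	name := stablehlo.NormalizeIdentifier("1st value!")
	// Invalid characters become underscores, and the leading digit is
	// prefixed with one: name == "_1st_value_".
	_ = name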

Types

type Builder

type Builder struct {
	// contains filtered or unexported fields
}

Builder is used to construct a StableHLO program (or "Module"). See details in New.

func New

func New(name string) *Builder

New creates a new Builder object holding a computation graph in construction.

From a builder you can create functions. For each function you create operations (ops) one by one, until you have defined the desired computation.

You have to define the "main" function for your StableHLO program: you can use Builder.Main to do so, or Builder.NewFunction("main", ...); they are equivalent.

Once you are all set, call Builder.Build and it will return the StableHLO program (or "Module") as a []byte that can be used with PJRT.

See github.com/gomlx/gopjrt for a Go API to PJRT.

func (*Builder) Build

func (b *Builder) Build() ([]byte, error)

Build checks the validity and builds the StableHLO program.

If you want the output of an incomplete program (without the checking), use Builder.Write instead.

func (*Builder) Main

func (b *Builder) Main(inputs ...*Value) *Function

Main creates the main function of the program. It is an alias to Builder.NewFunction("main", inputs...).

The main function is the entry point of the program, and it's the only function that can be called from outside the program.

Every program must have a main function.

Like with NewFunction, you can add new inputs later by calling Function.Input.

func (*Builder) Meshes added in v0.2.0

func (b *Builder) Meshes() []*shardy.DeviceMesh

Meshes returns the meshes configured with WithShardy.

func (*Builder) NewFunction

func (b *Builder) NewFunction(name string, inputs ...*Value) *Function

NewFunction creates a new function and adds it to the program. The function outputs will be determined by the last statement in the function body.

The function name must be unique in the program.

The inputs are the values that the function will receive as arguments. The values are not added to the program, they are just used as inputs.

You can also add new inputs later by calling Function.Input.

The function body is defined by calling ops on the function object.

See Function.

func (*Builder) NewShardingSpec added in v0.2.0

func (b *Builder) NewShardingSpec() *shardy.ShardingSpec

NewShardingSpec creates a new ShardingSpec using the first mesh configured with WithShardy. It returns nil if no mesh was configured.

This is a shortcut to NewShardingSpecByMeshIx(0).

func (*Builder) NewShardingSpecByMeshIx added in v0.2.0

func (b *Builder) NewShardingSpecByMeshIx(meshIdx int) *shardy.ShardingSpec

NewShardingSpecByMeshIx creates a new ShardingSpec for the meshIdx (the order given by WithShardy).

It may return nil if meshIdx is out of range.

func (*Builder) WithNumPartitions added in v0.1.0

func (b *Builder) WithNumPartitions(n int) *Builder

WithNumPartitions sets the number of partitions (for model parallelism). This is added as an attribute to the StableHLO module.

Consider using WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.

func (*Builder) WithNumReplicas added in v0.1.0

func (b *Builder) WithNumReplicas(n int) *Builder

WithNumReplicas sets the number of replicas (for data parallelism). This is added as an attribute to the StableHLO module.

Consider using WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.

func (*Builder) WithShardy added in v0.2.0

func (b *Builder) WithShardy(meshes ...*shardy.DeviceMesh) *Builder

WithShardy enables distributed computation across the devices selected by the given meshes.

This is the recommended way to do distributed (across devices) computation: given inputs with sharding information, Shardy will automatically distribute the computation, without you needing to specify any of the collective operations.

Usually there is only one mesh, but one can split the devices into different meshes. The meshes are overlaid on the concrete devices used.

See details of XLA Shardy in [1]

[1] https://github.com/openxla/shardy

func (*Builder) Write

func (b *Builder) Write(writer io.Writer) error

Write the StableHLO program (a readable string) to the given writer.

It will write incomplete programs (e.g., without a main function, or with empty statements) without raising an error, to help debugging.

See Builder.Build to check and output the program.

type Computation

type Computation struct {
	Name      string
	StableHLO string
}

Computation holds a rendered computation graph that can be fed to PJRT. It is created with Builder.Build.

type DotGeneralBuilder

type DotGeneralBuilder struct {
	// contains filtered or unexported fields
}

DotGeneralBuilder is a builder for DotGeneral nodes. See DotGeneral for more details.

func DotGeneral

func DotGeneral(
	lhsOp *Value, lhsContractingAxes, lhsBatchAxes []int,
	rhsOp *Value, rhsContractingAxes, rhsBatchAxes []int) *DotGeneralBuilder

DotGeneral takes as input lhs (left-hand-side) and rhs (right-hand-side) specifications for a general vector product -- a generalized "Einsum". Each axis can be:

  • Just aligned (batch axes), so the output has the same axes as the inputs. The dimensions must match in lhs and rhs.
  • Crossed (default), in which case the output is the combination (concatenation) of the dimensions.
  • Contracted (contracting axes), where the values are multiplied and reduce-summed over those dimensions.

It follows that the resulting shape starts with the batch axes, then the 'lhs' non-contracting/non-batch axes, and finally the 'rhs' non-contracting/non-batch axes. This provides the basic means of implementing Einsum.

Because there are optional parameters, this function returns a DotGeneralBuilder that can be further configured. Call DotGeneralBuilder.Done to get the final DotGeneral node.
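
For instance, a batched matrix multiplication could look like the sketch below, assuming lhs ([4, 8, 16]) and rhs ([4, 16, 32]) are previously created values:

	// Batch axis 0 on both sides; contract lhs axis 2 with rhs axis 1.
	// Output shape: [4, 8, 32] -- batch axes, then lhs rows, then rhs columns.
	out, err := stablehlo.DotGeneral(
		lhs, []int{2}, []int{0},
		rhs, []int{1}, []int{0}).Done()
	if err != nil {
		log.Fatal(err)
	}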

func (*DotGeneralBuilder) Algorithm

Algorithm sets the algorithm settings to use for the dot-general operation.

The default is not to set any of these parameters.

See details in types.DotGeneralAlgorithm.

func (*DotGeneralBuilder) Done

func (b *DotGeneralBuilder) Done() (*Value, error)

Done indicates the end of the DotGeneralBuilder configuration. It checks the validity of the parameters and shapes and returns the final DotGeneral node.

func (*DotGeneralBuilder) OutputDType

func (b *DotGeneralBuilder) OutputDType(dtype dtypes.DType) *DotGeneralBuilder

OutputDType sets the output data type: for input types like BFloat16 one may want to increase the output precision.

func (*DotGeneralBuilder) Precision

func (b *DotGeneralBuilder) Precision(lhsPrecision, rhsPrecision types.DotGeneralPrecisionType) *DotGeneralBuilder

Precision sets the precision of the dot-general operation.

Its default is described as "the fastest calculation, but the least accurate approximation to the original number."

It controls the tradeoff between speed and accuracy for computations on accelerator backends. At the moment, the semantics of the types.DotGeneralPrecisionType enum values are underspecified, but there are plans to address this in #755 -- https://github.com/openxla/stablehlo/issues/755.

type Function

type Function struct {
	Builder *Builder

	// Name of the function. It should not include the "@" prefix.
	Name string

	// Inputs to the function.
	Inputs []*Value

	// Outputs of the function.
	Outputs []*Value

	// Statements in the function body.
	Statements []*Statement

	// Parent of a closure function. It is only set if the function is a closure, and it's the function that created it.
	Parent *Function

	// Returned indicates if the function has a return statement, so it can no longer be changed.
	Returned bool
	// contains filtered or unexported fields
}

Function represents a `func.func` in StableHLO.

func (*Function) Closure

func (fn *Function) Closure() *Function

Closure creates an unnamed closure function that can be used as an argument to operations like Reduce, ReduceWindow, ScatterAndUpdate, etc.

Once created, the Closure should not be changed, but it can be used multiple times within the same parent function.

The function body is defined by calling ops on the function object, as with a regular Function object.

func (*Function) ConstantFromFlatAndDimensions

func (fn *Function) ConstantFromFlatAndDimensions(flat any, dimensions ...int) (*Value, error)

ConstantFromFlatAndDimensions creates a new constant statement from a flat slice with the raw values and the dimensions of the shape.

func (*Function) ConstantFromScalar

func (fn *Function) ConstantFromScalar(value any) (*Value, error)

ConstantFromScalar creates a new constant statement and returns the resulting value.

func (*Function) Input

func (fn *Function) Input(shape shapes.Shape) (*Value, error)

Input creates a new input parameter for a function.

If creating multiple inputs (one at a time), the order matters, since during execution of a compiled function, the input parameters must be given in the same order they were created.

These add to the inputs already created during the function creation.

It picks a default unique name for the input parameter; you can also provide a name with NamedInput.

func (*Function) InputWithAttributes added in v0.2.0

func (fn *Function) InputWithAttributes(shape shapes.Shape, attributes map[string]any) (*Value, error)

InputWithAttributes creates a new input with the given attributes.

func (*Function) InputWithSharding added in v0.2.0

func (fn *Function) InputWithSharding(shape shapes.Shape, shardingSpec *shardy.ShardingSpec) (*Value, error)

InputWithSharding creates a new input with the given sharding specification.

func (*Function) InputWithShardingAndAttributes added in v0.2.0

func (fn *Function) InputWithShardingAndAttributes(shape shapes.Shape, shardingSpec *shardy.ShardingSpec, attributes map[string]any) (*Value, error)

InputWithShardingAndAttributes creates a new input with the given sharding specification and attributes.

func (*Function) Iota

func (fn *Function) Iota(shape shapes.Shape, axis int) (*Value, error)

Iota creates a constant of the given shape with increasing numbers (starting from 0) on the given axis. So Iota([2,2], 1) returns [[0 1][0 1]], while Iota([2,2], 0) returns [[0 0][1 1]].

func (*Function) NamedInput

func (fn *Function) NamedInput(name string, shape shapes.Shape) (*Value, error)

NamedInput creates a new input parameter for a function with the given name -- it must be a unique input name.

The name is passed through ConvertToValidName, which converts any character that is not a digit, an ASCII letter, or an underscore to an underscore.

Names with the format "%d" and "arg%d" are reserved for the default input parameters.

Names are used in the StableHLO code and may be helpful for debugging, but otherwise have no impact.

func (*Function) NamedInputWithAttributes added in v0.2.0

func (fn *Function) NamedInputWithAttributes(name string, shape shapes.Shape,
	attributes map[string]any) (*Value, error)

NamedInputWithAttributes creates a new input parameter for a function with the given name and attributes.

func (*Function) NamedInputWithSharding added in v0.2.0

func (fn *Function) NamedInputWithSharding(name string, shape shapes.Shape,
	shardingSpec *shardy.ShardingSpec) (*Value, error)

NamedInputWithSharding creates a new input parameter for a function with the given name -- it must be a unique input name -- and sharding specification for distributed computation.

func (*Function) NamedInputWithShardingAndAttributes added in v0.2.0

func (fn *Function) NamedInputWithShardingAndAttributes(name string, shape shapes.Shape,
	shardingSpec *shardy.ShardingSpec, attributes map[string]any) (*Value, error)

NamedInputWithShardingAndAttributes creates a new input parameter for a function with the given name -- it must be a unique input name -- and sharding specification for distributed computation.

The shardingSpec can be nil: the default is a replicated input across all devices.

The name is passed through ConvertToValidName, which converts any character that is not a digit, an ASCII letter, or an underscore to an underscore.

Names with the format "%d" and "arg%d" are reserved for the default input parameters.

Names are used in the StableHLO code and may be helpful for debugging, but otherwise have no impact.

func (*Function) Return

func (fn *Function) Return(values ...*Value) error

Return adds a return statement to the function with the given return values. There must be at least one return value.

There can be only one return statement from a Function, and it must be the last operation of a function.

If you are doing distributed computation, you can use Function.ReturnWithShardingAndAttributes to specify the sharding requirements for each of the return values.

func (*Function) ReturnWithAttributes added in v0.2.0

func (fn *Function) ReturnWithAttributes(values []*Value, attributes []map[string]any) error

ReturnWithAttributes adds a return statement to the function with the given return values and attributes.

func (*Function) ReturnWithShardingAndAttributes added in v0.2.0

func (fn *Function) ReturnWithShardingAndAttributes(values []*Value, shardingSpecs []*shardy.ShardingSpec,
	attributes []map[string]any) error

ReturnWithShardingAndAttributes is a convenience function to call ReturnWithAttributes with the given sharding specifications.

The shardingSpecs slice of ShardingSpecs must have the same length as the values slice. Each ShardingSpec can be nil, in which case the default sharding is replicated across all devices. If shardingSpecs is nil, this behaves just like ReturnWithAttributes.

The attributes slice of maps can be set to nil if there are no attributes.

func (*Function) Write

func (fn *Function) Write(writer io.Writer, indentation string) error

Write the function as StableHLO code, with the given indentation.

type Statement

type Statement struct {
	Builder  *Builder
	Function *Function

	// OpType is the type of the operation.
	OpType optypes.OpType

	// Inputs to the operation.
	Inputs []*Value

	// Attributes of the operation.
	Attributes map[string]any

	// FunctionParameters for statements with operations like Reduce, ReduceWindow, ScatterAndUpdate, etc.
	FunctionParameters      []*Function
	FunctionParametersNames []string

	// Outputs of the operation. It may be nil for operations like func.return.
	Outputs []*Value
}

Statement represents a single operation line in StableHLO.

func (*Statement) AddFunctionParameter

func (s *Statement) AddFunctionParameter(name string, inlineFn *Function)

func (*Statement) Write

func (s *Statement) Write(writer io.Writer, indentation string) error

Write writes a string representation of the statement to the given writer.

type Value

type Value struct {
	Attributes map[string]any
	// contains filtered or unexported fields
}

Value represents a value in a StableHLO program, like `%0` or `%arg0`. These values can be inputs, outputs or intermediary values of functions.

It is always associated with a function (where it's being used) and must be uniquely identified by a name composed of the characters '0'-'9', 'A'-'Z', 'a'-'z', or '_'.

For inlined functions (for instance, the one passed to a Reduce operation), the names cannot clash with names in the parent function (!?). But names can be reused across different inline functions.

It also carries its shape information.

func Abs

func Abs(operand *Value) (*Value, error)

Abs implements the corresponding standard unary operation.

func Add

func Add(lhs, rhs *Value) (*Value, error)

Add implements the corresponding standard binary operation.

func AllGather added in v0.1.0

func AllGather(operand *Value, replicaGroups [][]int, allGatherDim int, config ...*types.CollectiveConfig) (*Value, error)

AllGather concatenates the operand from each replica along a specified dimension.

  • operand: The tensor from the *local* replica to be gathered.
  • replicaGroups: A 2D array defining the communicating device groups.
  • allGatherDim: The dimension along which to concatenate the operands.
  • config: Optional configuration of the channels to be used.

Consider using Builder.WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.

func AllReduce added in v0.1.0

func AllReduce(operands []*Value, replicaGroups [][]int, computation *Function, config ...*types.CollectiveConfig) (
	[]*Value, error)

AllReduce performs a distributed reduce operation across replicas. It is a distributed version of Reduce.

  • operands: The tensors from the *local* replica to be reduced.
  • replicaGroups: A 2D array defining the communicating device groups. For standard data parallelism, this is typically a single group with all the replica numbers, e.g., `[[0, 1, 2, 3]]` -- notice these are replica numbers, not device numbers (there is an indirection), except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device numbers.
  • computation: A closure function that defines the reduction operation (e.g., SUM). It must take two scalar inputs for each operand's dtype and return one scalar output of the same dtype.
  • config: Optional configuration of the channels to be used. This is not needed for SPMD programs.

Consider using Builder.WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.

func AllToAll added in v0.1.0

func AllToAll(operand *Value, replicaGroups [][]int, splitDimension, concatDimension, splitCount int, config ...*types.CollectiveConfig) (*Value, error)

AllToAll splits the operand along a specified dimension and scatters the chunks to all replicas, where they are concatenated back together.

  • operand: The tensor from the *local* replica.
  • replicaGroups: A 2D array defining the communicating device groups.
  • splitDimension: The dimension along which to split the operand.
  • concatDimension: The dimension along which to concatenate the received chunks.
  • splitCount: The number of chunks to split the operand into. This must match the size of the replica groups.
  • config: Optional configuration of the channels to be used.

Consider using Builder.WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.

func And

func And(lhs, rhs *Value) (*Value, error)

And implements the corresponding standard binary operation.

func Atan2

func Atan2(lhs, rhs *Value) (*Value, error)

Atan2 implements the corresponding standard binary operation.

func BatchNormInference

func BatchNormInference(operand, scale, offset, mean, variance *Value, epsilon float32, featureAxis int) (*Value, error)

BatchNormInference implements batch normalization for inference. See details in https://www.tensorflow.org/xla/operation_semantics#batchnorminference.

Based on the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Sergey Ioffe, Christian Szegedy), https://arxiv.org/abs/1502.03167.

func BitcastConvert

func BitcastConvert(operand *Value, targetDtype dtypes.DType) (*Value, error)

BitcastConvert performs an elementwise bit-cast operation from a dtype to another dtype.

BitcastConvert doesn't "convert"; rather, it just reinterprets the bits from x.DType() to the targetDtype.

If x.DType() and targetDtype use the same number of bytes (targetDtype.Size() == x.DType().Size()), the dimensions are not changed; only the dtype is.

If targetDtype.Size() > x.DType().Size(), it requires x's last axis to have a dimension of targetDtype.Size() / x.DType().Size(), and the returned shape drops that last axis.

If targetDtype.Size() < x.DType().Size(), the returned shape gains an extra axis at the end, with dimension x.DType().Size() / targetDtype.Size().

E.g.: Bitcast([1]uint32{0xdeadbeef}, dtypes.UInt16) -> [1][2]uint16{{0xbeef, 0xdead}} // Little-endian encoding.

func BroadcastInDim

func BroadcastInDim(operand *Value, target shapes.Shape, axesMapping []int) (*Value, error)

BroadcastInDim broadcasts dimensions from the operand to the target shape. It can also transpose axes and add new ones.

The axesMapping should have one value per operand axis. It maps each axis of the operand to the corresponding axis on the target shape.
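
A sketch, assuming vec is a previously created [3] float32 value, and shapes.Make and dtypes are the usual (assumed) helpers:

	// Broadcast a [3] vector to [2, 3]: operand axis 0 maps to target axis 1,
	// and the new axis 0 is filled by broadcasting.
	target := shapes.Make(dtypes.Float32, 2, 3)
	rows, err := stablehlo.BroadcastInDim(vec, target, []int{1})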

func Cbrt

func Cbrt(operand *Value) (*Value, error)

Cbrt implements the corresponding standard unary operation.

func Ceil

func Ceil(operand *Value) (*Value, error)

Ceil implements the corresponding standard unary operation.

func Clamp

func Clamp(min, x, max *Value) (*Value, error)

Clamp returns the minimum(maximum(x, min), max).

The values max and min can either be a scalar or have the same shape as x.

Clamp is not defined for booleans or complex numbers (the semantics would not be clear).

Note: the order of the arguments in StableHLO is different from most ML libraries.

func CollectiveBroadcast added in v0.1.0

func CollectiveBroadcast(operand *Value, replicaGroups [][]int, config ...*types.CollectiveConfig) (*Value, error)

CollectiveBroadcast broadcasts the value from the first replica (in each group) to all others. The returned shape is the same as the source. Devices not included in any replica group will return zeros as their output (with the same shape as the input).

  • operand: The tensor to be broadcast. In an SPMD setup, this op will be called on all replicas, but only the operand from the source device (the first device in the replica group) will be used.
  • replicaGroups: A 2D array defining the communicating device groups. For standard data parallelism, this is typically a single group with all the replica numbers, e.g., `[[0, 1, 2, 3]]` -- notice these are replica numbers, not device numbers (there is an indirection), except if the config sets UseGlobalDeviceIDs, in which case they are interpreted as device numbers.
  • config: Optional configuration of the channels to be used. This shouldn't be used for SPMD programs.

Consider using Builder.WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.

func CollectivePermute added in v0.1.0

func CollectivePermute(operand *Value, sourceTargetPairs [][2]int, config ...*types.CollectiveConfig) (*Value, error)

CollectivePermute sends the operand from a source replica to a target replica.

  • operand: The tensor from the *local* replica.
  • sourceTargetPairs: A 2D array where each inner array is a `[source, target]` pair of replica IDs.
  • config: Optional configuration of the channels to be used.

Consider using Builder.WithShardy for distributed computation instead: other forms of distributed (collective) computation across devices are not tested and may not work.

func Compare

func Compare(lhs, rhs *Value, direction types.ComparisonDirection, compareType types.ComparisonType) (*Value, error)

Compare implements the corresponding standard binary operation.

For boolean data types (dtypes.Bool) use the types.CompareUnsigned type.

func Complex

func Complex(real, imag *Value) (*Value, error)

Complex builds the complex value by combining the real and imaginary parts element-wise.

func Concatenate

func Concatenate(axis int, operands ...*Value) (*Value, error)

Concatenate operands on the given axis.

All operands must have matching dimensions on every axis except the one being concatenated. It doesn't work with scalars -- use ExpandAxes. If there is only one operand, it is returned and this is a no-op.

func Convert

func Convert(x *Value, dtype dtypes.DType) (*Value, error)

Convert x to the given dtype.

For boolean to numeric conversions, false becomes 0 and true 1.

For complex to non-complex conversions, the imaginary part is discarded (or set to 0).

Currently, it doesn't work for quantized to/from regular tensors. Use UniformQuantize and UniformDequantize for that.

func Convolution

func Convolution(input, kernel *Value,
	strides []int, paddings [][2]int, inputDilations, kernelDilations []int,
	inputBatchAxis, inputChannelsAxis int, inputSpatialAxes []int,
	kernelInputChannelsAxis, kernelOutputChannelsAxis int, kernelSpatialAxes []int,
	outputBatchAxis, outputChannelsAxis int, outputSpatialAxes []int,
	channelGroupCount, batchGroupCount int,
	inputPrecision, kernelPrecision types.DotGeneralPrecisionType) (*Value, error)

Convolution performs a convolution supporting strides, padding, dilations, feature grouping, and batch grouping.

See description in https://openxla.org/stablehlo/spec#convolution

The parameters strides, paddings, inputDilations, and kernelDilations can be set to nil, and the default (zeros for paddings and ones for the others) will be used.

Note: since the spec mentions that window_reversal will be removed, we didn't include it in the API. If you need it, we can create an alternative API for Convolution with it.
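
A sketch of a plain 2D convolution in NHWC layout follows; input and kernel are assumed previously created values, and the precision constant names are assumptions (check the types package for the actual enum values):

	// input: [batch, height, width, channels] (NHWC); kernel: [kh, kw, in, out] (HWIO).
	out, err := stablehlo.Convolution(
		input, kernel,
		nil, nil, nil, nil, // default strides (1), paddings (0) and dilations (1)
		0, 3, []int{1, 2}, // input: batch axis 0, channels axis 3, spatial axes 1 and 2
		2, 3, []int{0, 1}, // kernel: input-channels axis 2, output-channels axis 3, spatial axes 0 and 1
		0, 3, []int{1, 2}, // output: NHWC, like the input
		1, 1, // no channel or batch grouping
		types.DotGeneralPrecisionDefault, types.DotGeneralPrecisionDefault) // assumed constant names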

func Cosine

func Cosine(operand *Value) (*Value, error)

Cosine implements the corresponding standard unary operation.

func CountLeadingZeros

func CountLeadingZeros(operand *Value) (*Value, error)

CountLeadingZeros implements the corresponding standard unary operation.

func Divide

func Divide(lhs, rhs *Value) (*Value, error)

Divide implements the corresponding standard binary operation.

func Dot added in v0.2.0

func Dot(lhs, rhs *Value) (*Value, error)

Dot implements the simplest case of DotGeneral, with no batch axes: it contracts the last axis of lhs with the first axis of rhs. So vector [n] dot vector [n] yields a scalar; matrix [m, k] dot vector [k] yields a vector [m]; and matrix [m, k] dot matrix [k, n] yields a matrix [m, n].

See DotGeneral for the generalized version, with configurable contracting and batch axes.

func DynamicSlice

func DynamicSlice(operand *Value, startIndices []*Value, sliceSizes []int) (*Value, error)

DynamicSlice extracts a slice from the operand at the startIndices position and the given sliceSizes.

  • operand: tensor from which to take the slice.
  • startIndices: scalar tensors, one per axis of operand: len(startIndices) == operand.Rank().
  • sliceSizes: static values, fixed so that the output shape remains static.

The startIndices are adjusted as follows:

adjustedStartIndices[i] = clamp(0, StartIndices[i], operand.Dimensions[i] - sliceSizes[i])
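
A sketch, assuming operand is a previously created [4, 4] value inside the function fn:

	// Take a [2, 2] window of the operand, starting at row 1, column 2.
	row, _ := fn.ConstantFromScalar(int32(1))
	col, _ := fn.ConstantFromScalar(int32(2))
	window, err := stablehlo.DynamicSlice(operand, []*stablehlo.Value{row, col}, []int{2, 2})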

func DynamicUpdateSlice

func DynamicUpdateSlice(operand, update *Value, startIndices []*Value) (*Value, error)

DynamicUpdateSlice updates the operand with the values given in update, at the position given by startIndices.

  • operand: original value to be updated.
  • update: values to "paste" on top of operand, at the position given by startIndices.
  • startIndices: scalar tensors, one per axis of operand: len(startIndices) == operand.Rank().

It returns a value with the same shape as the operand, with the values updated.

The startIndices are adjusted as follows:

adjustedStartIndices[i] = clamp(0, StartIndices[i], operand.Dimensions[i] - update.Dimensions[i])

func Erf

func Erf(operand *Value) (*Value, error)

Erf implements the corresponding standard unary operation.

func Exponential

func Exponential(operand *Value) (*Value, error)

Exponential implements the corresponding standard unary operation.

func ExponentialMinusOne

func ExponentialMinusOne(operand *Value) (*Value, error)

ExponentialMinusOne implements the corresponding standard unary operation.

func FFT

func FFT(x *Value, fftType types.FFTType, fftLength ...int) (*Value, error)

FFT calls the XLA FFT operation, which implements {Forward, Inverse} x {Complex, Real} versions. See documentation in https://openxla.org/stablehlo/spec#fft, but more details in XLA page here: https://openxla.org/xla/operation_semantics#fft.

If fftLength is not given, one is picked for you: the last axis dimension for types.FFTForward, types.FFTInverse, and types.FFTForwardReal, and (lastDim-1)*2 for types.FFTInverseReal.

The underlying Gopjrt implementation for CPU FFT is backed by Eigen's TensorFFT, and for GPU FFT it uses cuFFT.

func Floor

func Floor(operand *Value) (*Value, error)

Floor implements the corresponding standard unary operation.

func Gather

func Gather(operand, startIndices *Value, indexVectorAxis int,
	offsetOutputAxes, collapsedSliceAxes, operandBatchingAxes,
	startIndicesBatchingAxes, startIndexMap,
	sliceSizes []int, indicesAreSorted bool) (*Value, error)

Gather is a powerful but cumbersome gather operation. Full details in https://openxla.org/stablehlo/spec#gather.

The output of Gather has the same DType of the operand, from where we are pulling the data.

Its output shape will be composed of 2 parts:

  • Batch axes: they come from operandBatchingAxes/startIndicesBatchingAxes (they correspond to each other) and from the other axes of startIndices, except the "indexVectorAxis" (usually the last) that is used as the indices into the operand. (*)
  • "Offset axes": these are axes that come from the operand, the sizes given by sliceSizes. Notice that if sliceSizes for an axis is 1, and that axis is present in the collapsedSliceAxes list, this axis gets omitted in the output.

So in general output.Rank() = startIndices.Rank() - 1 + len(offsetAxes).

(*) One exception is if indexVectorAxis == startIndices.Rank(), in which case we assume there is an extra virtual axis in startIndices of size 1, in which case output.Rank() = startIndices.Rank() + len(offsetAxes).

Arguments:

  • operand: the values from where we are gathering. The output DType will follow the operand one.
  • startIndices: are the indices we want to gather. The axis pointed by indexVector lists the indices of the slice to be gathered in the operand array (their values are mapped to the axis in the operand according to startIndexMap). All other axes are "batch dimensions" and they will have equivalent axes (same dimensions) in the output.
  • indexVectorAxis: which of the axis in startIndices is collected and used as the start index for slices to be gathered in the operand. It is typically the last axis of startIndices, so startIndices.Shape.Rank()-1. There is a special case where indexVectorAxis == startIndices.Rank() in which case we assume there is an extra virtual axis in startIndices of size 1, in which case output.Rank() = startIndices.Rank() + len(offsetAxes).
  • offsetOutputAxes: _output_ axes (not the operand's) that will hold the "offset slices", slices that are not collapsed. It points to the position (axis) in the output where these slices should show up. The len(offsetOutputAxes) must match the dimension of indexVectorAxis (== startIndices.Dimensions[indexVectorAxis]). Notice every axis in the operand will either become an "offset axis" in the output, or be collapsed ("squeezed") from the output if included in collapsedSliceAxes. The axes in the output (given in offsetOutputAxes) map sequentially to the axes in the operand not present in collapsedSliceAxes. One must have Rank(operand) == len(collapsedSliceAxes) + len(offsetOutputAxes) + len(operandBatchingAxes).
  • collapsedSliceAxes: _operand_ axes (for which sliceSizes are 1) not to be included in the output. One must have sliceSizes[collapsedSliceAxes[i]] == 1 for all i.
  • operandBatchingAxes: operand's batching axes that have corresponding batching axes in the startIndices, and that will also be included in the output. One must have sliceSizes[operandBatchingAxes[i]] == 1 for all i. Also, one must have Rank(operand) == len(operandBatchingAxes) + len(collapsedSliceAxes) + len(offsetOutputAxes).
  • startIndicesBatchingAxes: startIndices' batching axes have corresponding batching axes in the operand, and that will also be included in the output.
  • startIndexMap: this maps which value in startIndices is used for which axis in the operand, selecting the slice to be gathered. Notice len(startIndexMap) must match startIndices.Dimensions[indexVectorAxis]. Also, len(startIndexMap) == len(offsetOutputAxes) -- offsetOutputAxes maps the same axes in the output. E.g.: if startIndices.shape=(2, 3), indexVectorAxis=1, operand.rank=4, and startIndexMap=[]int{0, 1, 2}, each row of startIndices points to the first 3 axes (0, 1 and 2) of the operand. For the operand axes not explicitly set (if len(startIndexMap) < operand.Rank()) and not part of operandBatchingAxes, the corresponding start index is taken as 0, and one sets the sliceSizes to take the desired slice (typically the full slice).
  • sliceSizes: a size for each of the operand's axes, so len(sliceSizes) == operand.Rank(). Once the start index from where to gather is resolved, this defines how much data to gather in each axis. Constraints: sliceSizes[collapsedSliceAxes[i]] == 1 and sliceSizes[operandBatchingAxes[j]] == 1, for all i, j.
  • indicesAreSorted: can be set to true if the user guarantees that startIndices are sorted (in ascending order, after scattering its values according to startIndexMap). This allows for some optimizations on some platforms.
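
Putting the arguments together, here is a sketch that gathers whole rows of a matrix, assuming operand is a [5, 4] float32 value and startIndices a [3, 1] integer value of row indices:

	// Output shape: [3, 4] -- the startIndices batch axis, then the row contents.
	rows, err := stablehlo.Gather(
		operand, startIndices,
		1,           // indexVectorAxis: the last axis of startIndices
		[]int{1},    // offsetOutputAxes: the row contents land on output axis 1
		[]int{0},    // collapsedSliceAxes: operand axis 0 (sliceSize 1) is omitted
		nil, nil,    // no operand/startIndices batching axes
		[]int{0},    // startIndexMap: each index selects operand axis 0
		[]int{1, 4}, // sliceSizes: one row, all 4 columns
		false)       // indicesAreSorted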

func Imag

func Imag(complex *Value) (*Value, error)

Imag returns the imaginary part of the complex value.

func IsFinite

func IsFinite(x *Value) (*Value, error)

IsFinite tests whether each element of operand is finite, i.e., if it is not positive nor negative infinity, and it is not NaN. It returns the same shape as the input, but with boolean values where each element is true if and only if the corresponding input element is finite.

func Log

func Log(operand *Value) (*Value, error)

Log implements the corresponding standard unary operation.

func LogPlusOne

func LogPlusOne(operand *Value) (*Value, error)

LogPlusOne implements the corresponding standard unary operation.

func Logistic

func Logistic(operand *Value) (*Value, error)

Logistic implements the corresponding standard unary operation.

func Maximum

func Maximum(lhs, rhs *Value) (*Value, error)

Maximum implements the corresponding standard binary operation.

func Minimum

func Minimum(lhs, rhs *Value) (*Value, error)

Minimum implements the corresponding standard binary operation.

func MultiReduce

func MultiReduce(inputs, initialValues []*Value, reductionFn *Function, axes ...int) ([]*Value, error)

MultiReduce reduces the input along the given axes.

Each resulting value i is initialized with initValues[i] (e.g.: for a sum, it's 0, for a product it is 1), and then each value is combined with it using the reduction function.

The reduction function must be created with Function.Closure. If there are N inputs and initialValues, the reduction function should have a signature (lhs_1, ..., lhs_N, rhs_1, ..., rhs_N) and output (out_1, ..., out_N), where lhs_i and rhs_i are scalars taken from the inputs.

It returns N results, one for each input.

See Reduce for a version that accepts a single input.

TODO: promotion of types doesn't seem to be working according to the spec in https://openxla.org/stablehlo/spec#reduce.

func MultiReduceWindow

func MultiReduceWindow(inputs, initialValues []*Value, reductionFn *Function,
	windowDimensions, strides, inputDilations, windowDilations []int,
	paddings [][2]int) ([]*Value, error)

MultiReduceWindow reduces the inputs using arbitrary windows around each element.

Each resulting element for inputs[i] is initialized with initValues[i] (e.g.: for a sum, it's 0, for a product it is 1), and then each value is combined with the window around the element using the reduction function.

The reduction function must be created with Function.Closure. If there are N inputs and initialValues, the reduction function should have a signature (lhs_1, ..., lhs_N, rhs_1, ..., rhs_N) and output (out_1, ..., out_N), where lhs_i and rhs_i are scalars.

It returns N results, one for each input.

See ReduceWindow for a version that accepts a single input.

If strides is not set, it defaults to the value of windowDimensions -- the stride matches the window size.

TODO: promotion of types doesn't seem to be working according to the spec in

func MultiScatter

func MultiScatter(inputs []*Value, scatterIndices *Value, updates []*Value,
	updateWindowAxes, insertedWindowAxes []int,
	inputBatchingAxes, scatterIndicesBatchingAxes []int,
	indexedInputAxes []int, indexVectorAxis int,
	indicesAreSorted, uniqueIndices bool,
	updateComputationFn *Function) ([]*Value, error)

MultiScatter is like Scatter, but takes N inputs and updates with only one set of indices, and performs the Scatter on all of them at the same time.

func Multiply

func Multiply(lhs, rhs *Value) (*Value, error)

Multiply implements the corresponding standard binary operation.

func Negate

func Negate(operand *Value) (*Value, error)

Negate implements the corresponding standard unary operation.

func Not

func Not(operand *Value) (*Value, error)

Not implements the corresponding standard unary operation.

func Or

func Or(lhs, rhs *Value) (*Value, error)

Or implements the corresponding standard binary operation.

func Pad

func Pad(x, fill *Value, paddingStart, paddingEnd, paddingInterior []int) (*Value, error)

Pad x at start, end or interior (interleaved) at arbitrary axes.

It adds padding values around and in-between the elements of x. For each axis:

  • paddingStart elements are inserted before the tensor. This value can be negative, in which case elements are removed from the start of the axis.
  • paddingEnd elements are appended after the tensor. This value can be negative, in which case elements are removed from the end of the axis.
  • paddingInterior elements are inserted between consecutive elements of the tensor. So setting paddingInterior[i]=2 for axis "i" means 2 elements will be inserted between every adjacent pair of elements. paddingInterior cannot be negative.

If any of the padding parameters is not given, it is set to 0 for all axes.

The fill value must be a scalar with the same DType as x and determines what value will be used for the padding.

The output shape is defined by:

For each axis i in x:
output.Dimensions[i] = paddingStart[i] + x.Dimensions[i] + max((x.Dimensions[i]-1), 0)*paddingInterior[i] + paddingEnd[i]
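
A sketch for a [3] float32 vector x, assuming a scalar zero of the same DType created inside the function fn:

	// Result length: 1 (start) + 3 (data) + 2 (one interior pad between each
	// adjacent pair) + 1 (end) = 7 elements.
	zero, _ := fn.ConstantFromScalar(float32(0))
	padded, err := stablehlo.Pad(x, zero, []int{1}, []int{1}, []int{1})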

func Popcnt

func Popcnt(operand *Value) (*Value, error)

Popcnt implements the corresponding standard unary operation.

func Power

func Power(lhs, rhs *Value) (*Value, error)

Power implements the corresponding standard binary operation.

func RNGBitGenerator added in v0.2.0

func RNGBitGenerator(state *Value, shape shapes.Shape, algorithm types.RNGBitGeneratorAlgorithm) (newState, values *Value, err error)

RNGBitGenerator generates the given shape filled with random bits. It takes the current random number generator (RNG) state, see RngState or RngStateFromSeed.

It returns the new state of the RNG and the generated values (with random bits) with the given shape.

The state shape depends on the algorithm:

  • types.RngDefault: PJRT implementation defined.
  • types.RngThreeFry: 2xUint64.
  • types.RngPhilox: 2xUint64 or 3xUint64.

func Real

func Real(complex *Value) (*Value, error)

Real returns the real part of the complex value.

func Reduce

func Reduce(x, initialValue *Value, reductionFn *Function, axes ...int) (*Value, error)

Reduce reduces the input along the given axes.

Each resulting value is initialized with initialValue (e.g., for a sum it's 0; for a product it's 1), and then each value is combined with it using the reduction function.

The reduction function must be created with Function.Closure; it should take scalar values as input and be associative and commutative.

The initialValue and x must have the same DType. This dtype must be promotable to the dtype accepted by the reduction function. The result dtype is the same as the output of the reduction function, so one could reduce-sum a 4-bit quantized tensor directly into a Float32.

See MultiReduce for a version that accepts multiple inputs and outputs.
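
A sketch of a sum-reduction over axis 1 of a [2, 3] float32 value x, built inside a function fn (shapes.Make and dtypes are the usual assumed helpers; error handling elided):

	// The reduction closure: (lhs, rhs) -> lhs + rhs, on scalars.
	sumFn := fn.Closure()
	lhs, _ := sumFn.Input(shapes.Make(dtypes.Float32))
	rhs, _ := sumFn.Input(shapes.Make(dtypes.Float32))
	out, _ := stablehlo.Add(lhs, rhs)
	_ = sumFn.Return(out)

	// Sum-reduce x over axis 1; the result has shape [2].
	zero, _ := fn.ConstantFromScalar(float32(0))
	summed, err := stablehlo.Reduce(x, zero, sumFn, 1)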

func ReduceWindow

func ReduceWindow(input, initialValue *Value, reductionFn *Function,
	windowDimensions, strides, inputDilations, windowDilations []int,
	padding [][2]int) (*Value, error)

ReduceWindow reduces the inputs using arbitrary windows around each element.

Each resulting element of the input is initialized with initialValue (e.g., for a sum it's 0; for a product it's 1), and then combined with the window around the element using the reduction function.

The reduction function must be created with Function.Closure. It should have the signature `(lhs, rhs) -> out`, where lhs, rhs, and out are scalars.

If strides is not set, it defaults to the value of windowDimensions -- the stride matches the window size.

See MultiReduceWindow for a version that supports reducing multiple inputs at once.

TODO: promotion of types doesn't seem to be working according to the spec in
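
A max-pooling sketch with 2x2 windows over a [4, 4] float32 value x, built inside a function fn (error handling elided; the "math" import provides math.Inf):

	// The reduction closure: element-wise maximum of two scalars.
	maxFn := fn.Closure()
	a, _ := maxFn.Input(shapes.Make(dtypes.Float32))
	b, _ := maxFn.Input(shapes.Make(dtypes.Float32))
	m, _ := stablehlo.Maximum(a, b)
	_ = maxFn.Return(m)

	// 2x2 max-pooling: strides default to the window dimensions, so pooled is [2, 2].
	negInf, _ := fn.ConstantFromScalar(float32(math.Inf(-1)))
	pooled, err := stablehlo.ReduceWindow(x, negInf, maxFn,
		[]int{2, 2}, // windowDimensions
		nil,         // strides: defaults to windowDimensions
		nil, nil,    // no input/window dilations
		nil)         // no padding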

func Remainder

func Remainder(lhs, rhs *Value) (*Value, error)

Remainder implements the corresponding standard binary operation.

func Reshape

func Reshape(operand *Value, shape shapes.Shape) (*Value, error)

Reshape the operand to the given shape. The total size of the new shape must match the original shape.

This has no effect on the data, no transposition is performed.

func Reverse

func Reverse(x *Value, axes ...int) (*Value, error)

Reverse axes of x.

E.g.: Reverse([1, 2, 3], axes=0) -> [3, 2, 1]

func RoundNearestAfz

func RoundNearestAfz(operand *Value) (*Value, error)

RoundNearestAfz implements the corresponding standard unary operation.

func RoundNearestEven

func RoundNearestEven(operand *Value) (*Value, error)

RoundNearestEven implements the corresponding standard unary operation.

func Rsqrt

func Rsqrt(operand *Value) (*Value, error)

Rsqrt implements the corresponding standard unary operation.

func Scatter

func Scatter(input, scatterIndices, updates *Value,
	updateWindowAxes, insertedWindowAxes []int,
	inputBatchingAxes, scatterIndicesBatchingAxes []int,
	indexedInputAxes []int, indexVectorAxis int,
	indicesAreSorted, uniqueIndices bool,
	updateComputationFn *Function) (*Value, error)

Scatter returns the input updated with the values of updates at the locations pointed to by scatterIndices. It allows axes to be used in powerful ways, but it's complex to get right. Full details in https://openxla.org/stablehlo/spec#scatter.

The output of Scatter has the same shape and DType of the input.

Batching: while batching axes are only defined for the input and scatterIndices, the batching axes for the updates are inferred from the scatterIndices.

Arguments:

  • input: value to be updated in a scattered fashion.
  • scatterIndices: indices of the values to be scattered.
  • updates: updated values to be scattered at scatterIndices.
  • updateWindowAxes: these axes provide the shape of the update window.
  • insertedWindowAxes: in the resulting tensor, each axis is either a batch axis, part of the update window (not specified, taken sequentially), or one of the insertedWindowAxes defined by this argument.
  • inputBatchingAxes: axes that are batched with the input.
  • scatterIndicesBatchingAxes: axes that are batched with the scatterIndices.
  • indexedInputAxes: axes that are indexed by the scatterIndices at axis indexVectorAxis (aka. "scatter_dims_to_operand_dims").
  • indexVectorAxis: the axis in scatterIndices that holds a vector of indices into the input where to scatter. This vector, of length scatterIndices.Dimensions[indexVectorAxis], defines the index value in the input on the axes given by indexedInputAxes. E.g.: with indexedInputAxes = [0, 1], indexVectorAxis = 0, and scatterIndices = [[0, 1, 2], [3, 4, 5]], the values from updates[0] are scattered at input[0, 3], updates[1] at input[1, 4], and so on. The shape of scatterIndices is then [2, ...], with the index vector of length 2 on axis 0.
  • indicesAreSorted, uniqueIndices: they can be set to true if it's guaranteed that scatterIndices are sorted (in ascending order) and/or unique (no duplicates). This allows for some optimizations on some platforms.
  • updateComputationFn: the closure that element-wise combines the current input value and the update value, computing the result. It also defines the data type of the outputs: if the updateComputationFn inputs and outputs don't match the corresponding DType of the inputs and updates, the values from inputs and updates must be "promotable" to the DType of updateComputationFn. Notice it may be called multiple times for some elements if the indices are not unique or the updates' windows overlap.
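
Putting the arguments together, here is a scatter-add sketch, assuming input ([5, 4]), scatterIndices ([3, 1]) and updates ([3, 4]) are previously created values inside the function fn (error handling elided):

	// The update closure adds the new value into the existing one (scatter-add).
	addFn := fn.Closure()
	cur, _ := addFn.Input(shapes.Make(dtypes.Float32))
	upd, _ := addFn.Input(shapes.Make(dtypes.Float32))
	sum, _ := stablehlo.Add(cur, upd)
	_ = addFn.Return(sum)

	result, err := stablehlo.Scatter(
		input, scatterIndices, updates,
		[]int{1},     // updateWindowAxes: updates' axis 1 holds the row contents
		[]int{0},     // insertedWindowAxes: input axis 0 is indexed, not windowed
		nil, nil,     // no input/scatterIndices batching axes
		[]int{0},     // indexedInputAxes: indices select input axis 0
		1,            // indexVectorAxis: the last axis of scatterIndices
		false, false, // indicesAreSorted, uniqueIndices
		addFn)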

func Select

func Select(pred, onTrue, onFalse *Value) (*Value, error)

Select takes element-wise values from onTrue or onFalse depending on the value of the pred (must be boolean).

The pred must be boolean and can be a scalar or have the same shape as onTrue and onFalse. onTrue and onFalse must have the same shape and dtype.

func SelectAndScatter

func SelectAndScatter(input, scatterSource, initialValue *Value,
	selectFn, scatterFn *Function,
	windowDimensions, strides []int, paddings [][2]int) (*Value, error)

SelectAndScatter performs a ReduceWindow on the input, selecting one value per window (using the selectFn to choose the value), and then aggregating this value into the output (at the same index as the input).

The returned result has the same shape as the input: it is initialized with initialValue, and the selected values are aggregated into it with scatterFn.

func ShiftLeft

func ShiftLeft(lhs, rhs *Value) (*Value, error)

ShiftLeft implements the corresponding standard binary operation.

func ShiftRightArithmetic

func ShiftRightArithmetic(lhs, rhs *Value) (*Value, error)

ShiftRightArithmetic implements the corresponding standard binary operation.

func ShiftRightLogical

func ShiftRightLogical(lhs, rhs *Value) (*Value, error)

ShiftRightLogical implements the corresponding standard binary operation.

func Sign

func Sign(operand *Value) (*Value, error)

Sign implements the corresponding standard unary operation.

func Sine

func Sine(operand *Value) (*Value, error)

Sine implements the corresponding standard unary operation.

func Slice

func Slice(x *Value, starts, limits, strides []int) (*Value, error)

Slice extracts a subarray from the input array. The subarray is of the same rank as the input and contains the values inside a bounding box within the input array where the dimensions and indices of the bounding box are given as arguments to the slice operation. The strides set the input stride of the slice in each axis and must be >= 1. It is optional, and if missing, it is assumed to be 1 for every dimension. Examples:

Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={4}, strides=nil) -> {2, 3}
Slice(x={0, 1, 2, 3, 4}, starts={2}, limits={5}, strides={2}) -> {2, 4}

func Sqrt

func Sqrt(operand *Value) (*Value, error)

Sqrt implements the corresponding standard unary operation.

func Subtract

func Subtract(lhs, rhs *Value) (*Value, error)

Subtract implements the corresponding standard binary operation.

func Tan

func Tan(operand *Value) (*Value, error)

Tan implements the corresponding standard unary operation.

func Tanh

func Tanh(operand *Value) (*Value, error)

Tanh implements the corresponding standard unary operation.

func Transpose

func Transpose(x *Value, permutation ...int) (*Value, error)

Transpose axes of x.

There should be one value in permutation for each axis in x (len(permutation) == rank(x)).

The output will have: output.Shape.Dimensions[i] = x.Shape.Dimensions[permutation[i]].
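
For example, a sketch that swaps the two axes of a [2, 3] value x:

	transposed, err := stablehlo.Transpose(x, 1, 0) // result shape: [3, 2]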

func Xor

func Xor(lhs, rhs *Value) (*Value, error)

Xor implements the corresponding standard binary operation.

func (*Value) Shape

func (v *Value) Shape() shapes.Shape

Shape returns the shape of the value.

func (*Value) String

func (v *Value) String() string

String implements fmt.Stringer.

func (*Value) Write

func (v *Value) Write(w io.Writer, indentation string) error

Write writes the value in StableHLO text format to the given writer.

Directories

Path Synopsis
internal
internal/optypes
Package optypes defines OpType and lists the supported operations.
internal/utils
Package utils holds small utility types and functions used internally in stablehlo.
shapeinference
Package shapeinference calculates the shape resulting from operations and validates its inputs.
shapes
Package shapes defines Shape and DType and associated tools.
shardy
Package shardy provides the types needed to define a distributed computation topology.
