### 🐛 Describe the bug
The following graph with a bfloat16 relu fails to lower with the error below. Note that the dtypes reported as "mismatched" are both `torch.bfloat16`, so the verifier appears to be rejecting a pair that actually matches:
```
SpecViolationError: These operators are taking Tensor inputs with mismatched dtypes:
Operator: <EdgeOpOverload: aten.relu.default>: schema = aten::relu(Tensor self) -> Tensor with args: {'self': torch.bfloat16, '__ret_0': torch.bfloat16}

stack trace: File "/var/folders/90/5w9gk0fn4n3g7fw1bvq8r1_m0000gn/T/ipykernel_98835/2855951395.py", line 11, in forward
    return torch.nn.functional.relu(x)

Please make sure the dtypes of the Tensor inputs are the same as the dtypes of the corresponding outputs.
```
Repro:
```python
import torch

from executorch.backends.apple.coreml.partition import CoreMLPartitioner
from executorch.exir import to_edge_transform_and_lower
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.nn.functional.relu(x)


model = Model()
inputs = (torch.randn(8).to(torch.bfloat16),)

eager_outputs = model(*inputs)
print(f"Eager: {eager_outputs.shape} {eager_outputs}")

lowered = to_edge_transform_and_lower(
    torch.export.export(model, inputs),
    # partitioner=[CoreMLPartitioner()],
).to_executorch()

et_model = _load_for_executorch_from_buffer(lowered.buffer)
et_outputs = et_model([*inputs])[0]
et_outputs - eager_outputs
```
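For comparison, a minimal sanity check (my assumption: the failure is specific to the Edge-dialect verifier during lowering, not to eager execution) shows that eager-mode relu handles bfloat16 fine and preserves the dtype:

```python
import torch

# Eager-mode relu on a bfloat16 tensor: no dtype error is raised,
# and the output dtype matches the input dtype.
x = torch.randn(8).to(torch.bfloat16)
y = torch.relu(x)
print(y.dtype)  # torch.bfloat16
assert y.dtype == torch.bfloat16
assert (y >= 0).all()
```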
### Versions
executorch commit 67b6009 (Jun 14)