Skip to content

Commit da5297b

Browse files
ai-edge-botcopybara-github
authored andcommitted
Add utility functions for TFL tensors
LiteRT-PiperOrigin-RevId: 748467280
1 parent 7026434 commit da5297b

File tree

4 files changed

+68
-14
lines changed

4 files changed

+68
-14
lines changed

litert/runtime/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,17 @@ cc_library(
170170
"tfl_utils.h",
171171
],
172172
deps = [
173+
"@com_google_absl//absl/strings:str_format",
173174
"//litert/c:litert_common",
175+
"//litert/c:litert_layout",
174176
"//litert/c:litert_model",
175177
"//litert/cc:litert_detail",
176178
"//litert/cc:litert_element_type",
177179
"//litert/cc:litert_expected",
180+
"//litert/cc:litert_layout",
178181
"//litert/cc:litert_macros",
179182
"//litert/cc:litert_model",
183+
"//litert/cc:litert_tensor_buffer",
180184
"//litert/core/util:tensor_type_util",
181185
"//tflite/c:c_api",
182186
"//tflite/c:c_api_opaque",

litert/runtime/dispatch/dispatch_delegate_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ Expected<void> DispatchDelegateKernel::AllocateTensorBuffersIfNeeded() {
600600
litert_tensor_buffer = new_tensor_buffer.Get();
601601
LITERT_RETURN_IF_ERROR(buffer_context_->RegisterTensorBuffer(
602602
tfl_tensor, std::move(new_tensor_buffer)));
603-
LITERT_ASSIGN_OR_RETURN(auto tfl_tensor_size, GetTensorSize(tfl_tensor));
603+
size_t tfl_tensor_size = TfLiteOpaqueTensorByteSize(tfl_tensor);
604604
tensor_buffer_info.MarkAsMaybeSyncWithCpu(tfl_tensor_size);
605605
}
606606

litert/runtime/tfl_utils.cc

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,21 @@
1717
#include <cstddef>
1818
#include <utility>
1919

20+
#include "absl/strings/str_format.h" // from @com_google_absl
2021
#include "litert/c/litert_common.h"
22+
#include "litert/c/litert_layout.h"
2123
#include "litert/c/litert_model.h"
2224
#include "litert/cc/litert_detail.h"
2325
#include "litert/cc/litert_element_type.h"
2426
#include "litert/cc/litert_expected.h"
27+
#include "litert/cc/litert_layout.h"
2528
#include "litert/cc/litert_macros.h"
2629
#include "litert/cc/litert_model.h"
30+
#include "litert/cc/litert_tensor_buffer.h"
2731
#include "litert/core/util/tensor_type_util.h"
2832
#include "tflite/c/c_api_opaque.h"
2933
#include "tflite/c/c_api_types.h"
34+
#include "tflite/c/common.h"
3035

3136
namespace litert::internal {
3237

@@ -78,27 +83,60 @@ Expected<ElementType> ConvertElementType(TfLiteType tfl_type) {
7883
}
7984
}
8085

81-
Expected<RankedTensorType> ConvertTensorType(
86+
Expected<Layout> ConvertTensorLayout(
8287
const TfLiteOpaqueTensor* tfl_opaque_tensor) {
83-
auto tfl_type = TfLiteOpaqueTensorType(tfl_opaque_tensor);
84-
auto element_type = ConvertElementType(tfl_type);
85-
if (!element_type) {
86-
return Unexpected(element_type.Error());
87-
}
88-
8988
size_t rank = TfLiteOpaqueTensorNumDims(tfl_opaque_tensor);
9089
Dimensions dimensions(rank);
9190
for (size_t i = 0; i < rank; ++i) {
9291
dimensions[i] = TfLiteOpaqueTensorDim(tfl_opaque_tensor, i);
9392
}
93+
return Layout(std::move(dimensions));
94+
}
9495

95-
return RankedTensorType(*element_type, Layout(std::move(dimensions)));
96+
Expected<RankedTensorType> ConvertTensorType(
97+
const TfLiteOpaqueTensor* tfl_opaque_tensor) {
98+
auto tfl_type = TfLiteOpaqueTensorType(tfl_opaque_tensor);
99+
LITERT_ASSIGN_OR_RETURN(auto element_type, ConvertElementType(tfl_type));
100+
LITERT_ASSIGN_OR_RETURN(auto layout, ConvertTensorLayout(tfl_opaque_tensor));
101+
return RankedTensorType(element_type, std::move(layout));
96102
}
97103

98-
Expected<size_t> GetTensorSize(const TfLiteOpaqueTensor* tfl_opaque_tensor) {
104+
Expected<TensorBuffer> CreateHostTensorBufferFromTflTensor(
105+
TfLiteOpaqueContext* tfl_context,
106+
const TfLiteOpaqueTensor* tfl_opaque_tensor) {
99107
LITERT_ASSIGN_OR_RETURN(auto tensor_type,
100108
ConvertTensorType(tfl_opaque_tensor));
101-
return GetNumPackedBytes(static_cast<LiteRtRankedTensorType>(tensor_type));
109+
void* host_mem_addr = TfLiteOpaqueTensorData(tfl_opaque_tensor);
110+
size_t buffer_size = TfLiteOpaqueTensorByteSize(tfl_opaque_tensor);
111+
LITERT_ASSIGN_OR_RETURN(auto tensor_buffer,
112+
TensorBuffer::CreateFromHostMemory(
113+
tensor_type, host_mem_addr, buffer_size));
114+
return tensor_buffer;
115+
}
116+
117+
Expected<void> ResizeTensor(const LiteRtLayout& layout,
118+
TfLiteOpaqueContext* tfl_context,
119+
TfLiteOpaqueTensor* tfl_opaque_tensor) {
120+
// TFL tensors don't support strides.
121+
if (layout.strides) {
122+
return Unexpected(kLiteRtStatusErrorInvalidArgument,
123+
"Unexpected layout with strides");
124+
}
125+
126+
TfLiteIntArray* output_size = TfLiteIntArrayCreate(layout.rank);
127+
for (auto i = 0; i < layout.rank; ++i) {
128+
output_size->data[i] = layout.dimensions[i];
129+
}
130+
if (auto status = TfLiteOpaqueContextResizeTensor(
131+
tfl_context, tfl_opaque_tensor, output_size);
132+
status != kTfLiteOk) {
133+
return Unexpected(
134+
kLiteRtStatusErrorRuntimeFailure,
135+
absl::StrFormat("Failed to resize TFL tensor %s: %d",
136+
TfLiteOpaqueTensorName(tfl_opaque_tensor), status));
137+
}
138+
139+
return {};
102140
}
103141

104142
} // namespace litert::internal

litert/runtime/tfl_utils.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,33 @@
1515
#ifndef ODML_LITERT_LITERT_RUNTIME_TFL_UTILS_H_
1616
#define ODML_LITERT_LITERT_RUNTIME_TFL_UTILS_H_
1717

18+
#include "litert/c/litert_layout.h"
1819
#include "litert/cc/litert_expected.h"
20+
#include "litert/cc/litert_layout.h"
1921
#include "litert/cc/litert_model.h"
22+
#include "litert/cc/litert_tensor_buffer.h"
2023

2124
struct TfLiteOpaqueTensor;
2225

2326
namespace litert::internal {
2427

2528
Expected<ElementType> ConvertElementType(TfLiteType tfl_type);
2629

30+
Expected<Layout> ConvertTensorLayout(
31+
const TfLiteOpaqueTensor* tfl_opaque_tensor);
32+
2733
Expected<RankedTensorType> ConvertTensorType(
2834
const TfLiteOpaqueTensor* tfl_opaque_tensor);
2935

30-
// Return the size (in bytes) necessary to store a given TFL tensor's numeric
31-
// data.
32-
Expected<size_t> GetTensorSize(const TfLiteOpaqueTensor* tfl_opaque_tensor);
36+
// Create a TensorBuffer attached to the TFL tensor's data buffer.
37+
Expected<TensorBuffer> CreateHostTensorBufferFromTflTensor(
38+
TfLiteOpaqueContext* tfl_context,
39+
const TfLiteOpaqueTensor* tfl_opaque_tensor);
40+
41+
// Resize a given `tfl_opaque_tensor` based on a given `layout`.
42+
Expected<void> ResizeTensor(const LiteRtLayout& layout,
43+
TfLiteOpaqueContext* tfl_context,
44+
TfLiteOpaqueTensor* tfl_opaque_tensor);
3345

3446
} // namespace litert::internal
3547

0 commit comments

Comments
 (0)