|
17 | 17 | #include <cstddef>
|
18 | 18 | #include <utility>
|
19 | 19 |
|
| 20 | +#include "absl/strings/str_format.h" // from @com_google_absl |
20 | 21 | #include "litert/c/litert_common.h"
|
| 22 | +#include "litert/c/litert_layout.h" |
21 | 23 | #include "litert/c/litert_model.h"
|
22 | 24 | #include "litert/cc/litert_detail.h"
|
23 | 25 | #include "litert/cc/litert_element_type.h"
|
24 | 26 | #include "litert/cc/litert_expected.h"
|
| 27 | +#include "litert/cc/litert_layout.h" |
25 | 28 | #include "litert/cc/litert_macros.h"
|
26 | 29 | #include "litert/cc/litert_model.h"
|
| 30 | +#include "litert/cc/litert_tensor_buffer.h" |
27 | 31 | #include "litert/core/util/tensor_type_util.h"
|
28 | 32 | #include "tflite/c/c_api_opaque.h"
|
29 | 33 | #include "tflite/c/c_api_types.h"
|
| 34 | +#include "tflite/c/common.h" |
30 | 35 |
|
31 | 36 | namespace litert::internal {
|
32 | 37 |
|
@@ -78,27 +83,60 @@ Expected<ElementType> ConvertElementType(TfLiteType tfl_type) {
|
78 | 83 | }
|
79 | 84 | }
|
80 | 85 |
|
81 |
| -Expected<RankedTensorType> ConvertTensorType( |
| 86 | +Expected<Layout> ConvertTensorLayout( |
82 | 87 | const TfLiteOpaqueTensor* tfl_opaque_tensor) {
|
83 |
| - auto tfl_type = TfLiteOpaqueTensorType(tfl_opaque_tensor); |
84 |
| - auto element_type = ConvertElementType(tfl_type); |
85 |
| - if (!element_type) { |
86 |
| - return Unexpected(element_type.Error()); |
87 |
| - } |
88 |
| - |
89 | 88 | size_t rank = TfLiteOpaqueTensorNumDims(tfl_opaque_tensor);
|
90 | 89 | Dimensions dimensions(rank);
|
91 | 90 | for (size_t i = 0; i < rank; ++i) {
|
92 | 91 | dimensions[i] = TfLiteOpaqueTensorDim(tfl_opaque_tensor, i);
|
93 | 92 | }
|
| 93 | + return Layout(std::move(dimensions)); |
| 94 | +} |
94 | 95 |
|
95 |
| - return RankedTensorType(*element_type, Layout(std::move(dimensions))); |
| 96 | +Expected<RankedTensorType> ConvertTensorType( |
| 97 | + const TfLiteOpaqueTensor* tfl_opaque_tensor) { |
| 98 | + auto tfl_type = TfLiteOpaqueTensorType(tfl_opaque_tensor); |
| 99 | + LITERT_ASSIGN_OR_RETURN(auto element_type, ConvertElementType(tfl_type)); |
| 100 | + LITERT_ASSIGN_OR_RETURN(auto layout, ConvertTensorLayout(tfl_opaque_tensor)); |
| 101 | + return RankedTensorType(element_type, std::move(layout)); |
96 | 102 | }
|
97 | 103 |
|
98 |
| -Expected<size_t> GetTensorSize(const TfLiteOpaqueTensor* tfl_opaque_tensor) { |
| 104 | +Expected<TensorBuffer> CreateHostTensorBufferFromTflTensor( |
| 105 | + TfLiteOpaqueContext* tfl_context, |
| 106 | + const TfLiteOpaqueTensor* tfl_opaque_tensor) { |
99 | 107 | LITERT_ASSIGN_OR_RETURN(auto tensor_type,
|
100 | 108 | ConvertTensorType(tfl_opaque_tensor));
|
101 |
| - return GetNumPackedBytes(static_cast<LiteRtRankedTensorType>(tensor_type)); |
| 109 | + void* host_mem_addr = TfLiteOpaqueTensorData(tfl_opaque_tensor); |
| 110 | + size_t buffer_size = TfLiteOpaqueTensorByteSize(tfl_opaque_tensor); |
| 111 | + LITERT_ASSIGN_OR_RETURN(auto tensor_buffer, |
| 112 | + TensorBuffer::CreateFromHostMemory( |
| 113 | + tensor_type, host_mem_addr, buffer_size)); |
| 114 | + return tensor_buffer; |
| 115 | +} |
| 116 | + |
| 117 | +Expected<void> ResizeTensor(const LiteRtLayout& layout, |
| 118 | + TfLiteOpaqueContext* tfl_context, |
| 119 | + TfLiteOpaqueTensor* tfl_opaque_tensor) { |
| 120 | + // TFL tensors don't support strides. |
| 121 | + if (layout.strides) { |
| 122 | + return Unexpected(kLiteRtStatusErrorInvalidArgument, |
| 123 | + "Unexpected layout with strides"); |
| 124 | + } |
| 125 | + |
| 126 | + TfLiteIntArray* output_size = TfLiteIntArrayCreate(layout.rank); |
| 127 | + for (auto i = 0; i < layout.rank; ++i) { |
| 128 | + output_size->data[i] = layout.dimensions[i]; |
| 129 | + } |
| 130 | + if (auto status = TfLiteOpaqueContextResizeTensor( |
| 131 | + tfl_context, tfl_opaque_tensor, output_size); |
| 132 | + status != kTfLiteOk) { |
| 133 | + return Unexpected( |
| 134 | + kLiteRtStatusErrorRuntimeFailure, |
| 135 | + absl::StrFormat("Failed to resize TFL tensor %s: %d", |
| 136 | + TfLiteOpaqueTensorName(tfl_opaque_tensor), status)); |
| 137 | + } |
| 138 | + |
| 139 | + return {}; |
102 | 140 | }
|
103 | 141 |
|
104 | 142 | } // namespace litert::internal
|
0 commit comments