深入理解TVM:Python/C++互调(上)
一、概述
TVM已经是一个很庞大的系统,包含了很多的功能模块,其中python和c++的互相调用这个功能模块,没有使用第三方的开源库(boost.python、pybind11等),而是自己实现了一套复杂但精致高效强大的机制,值得好好研究学习。这部分内容很多,一篇文章很难说清楚,我准备把这部分分成上、中、下三篇来说,尽可能的把实现原理讲清楚:
- 上篇:最底层的c++数据结构支撑(围绕c++端PackedFunc)
- 中篇:基于PackedFunc的函数注册(围绕TVM_REGISTER_GLOBAL)
- 下篇:偏上层的python的调用细节(围绕ctypes内置库和python端PackedFunc)
本文讲第一部分,也就是围绕PackedFunc这个类来说,它是python和c++互调的桥梁,此类实现代码在include/tvm/runtime/packed_func.h文件中,这里面还有一个TypedPackedFunc类,它只是PackedFunc的一个wrapper,主要增加了类型检查的功能,开发TVM的c++代码要尽可能的使用这个类,但是我们为了把问题尽可能的简化,只关注PackedFunc这个最底层类,其中用到了下面这几个关键的数据结构:
- TVMValue
- TVMArgs
- TVMPODValue_
- TVMArgValue
- TVMRetValue
- TVMArgsSetter
下面结合代码,逐个来说(注:本文基于fe25b9e7c这个commit,下面所有列出的代码都做了相当大量的精简和修改,一来只为讲清楚原理,二来限于篇幅,大家如果有兴趣了解更多的细节,需要再去github上看实际的实现)
二、TVMValue
这是最基本的一个数据结构,是一个union,实现在include/tvm/runtime/c_runtime_api.h,主要是为了储存c++和其它语言交互时所支持的几种类型的数据,代码很简单(其中DLDataType和DLDevice是两个复合数据类型,限于篇幅,这里不能全部列出来,大家需要自己到github追下细节):
typedef union {
int64_t v_int64;
double v_float64;
void* v_handle;
const char* v_str;
DLDataType v_type;
DLDevice v_device;
} TVMValue;
三、TVMArgs
这个类主要是为了封装传给PackedFunc的所有参数,这个类也比较简单原始,主要基于TVMValue、参数类型编码、参数个数来实现,代码如下:
class TVMArgs {
public:
const TVMValue* values;
const int* type_codes;
int num_args;
TVMArgs(const TVMValue* values,
const int* type_codes,
int num_args) { ... }
inline int size() const { return num_args; }
inline TVMArgValue operator[](int i) const {
return TVMArgValue(values[i], type_codes[i]);
}
};
四、TVMPODValue_
这是一个内部使用的基类,主要主要服务于后面介绍到的TVMArgValue和TVMRetValue,从名字可以看出,这个类主要是处理POD类型的数据,POD是plain old data的缩写,要么是scalar type,要么是trival type,要么是standard layout type,具体可参考cppreference的PODType、is_pod、is_scalar、is_trivial、is_standard_layout等章节。其实关于POD类型,可以单独写一大篇文章,但它不是本文的重点,以后有时间再专门写文章细说。这个类的实现核心是强制类型转换运算符重载(在c++中,类型的名字,包括类的名字本身也是一种运算符,即类型强制转换运算符),如下面代码所示:
class TVMPODValue_ {
public:
operator double() const { return value_.v_float64; }
operator int64_t() const { return value_.v_int64; }
operator void*() const { return value_.v_handle; }
template <typename T>
T* ptr() const { return static_cast<T*>(value_.v_handle); }
protected:
TVMValue value_;
int type_code_;
};
五、TVMArgValue
这个类继承自前面的TVMPODValue_类,用作表示PackedFunc的一个参数,它和TVMPODValue_的区别是扩充了一些数据类型的支持,比如string、PackedFunc、TypedPackedFunc等,其中对后两个的支持是在c++代码中能够调用python函数的根本原因。这个类只使用所保存的underlying data,而不会去做释放,代码如下:
class TVMArgValue : public TVMPODValue_ {
public:
TVMArgValue() {}
TVMArgValue(TVMValue value, int type_code)
: TVMPODValue_(value, type_code) {}
operator std::string() const {}
operator PackedFunc() const { return *ptr<PackedFunc>(); }
const TVMValue& value() const { return value_; }
template <typename T>
inline operator T() const;
inline operator DLDataType() const;
inline operator DataType() const;
};
六、TVMRetValue
这个类也是继承自TVMPODValue_类,主要作用是作为存放调用PackedFunc返回值的容器,它和TVMArgValue的区别是,它会管理所保存的underlying data,会对其做释放。这个类主要由四部分构成:
- 构造和析构函数
- 对强制类型转换运算符重载的扩展
- 对赋值运算符的重载
- 辅助函数,包括释放资源的Clear函数
代码如下:
class TVMRetValue : public TVMPODValue_ {
public:
// ctor and dtor, dtor will release related buffer
TVMRetValue() {}
~TVMRetValue() { this->Clear(); }
// conversion operators
operator std::string() const { ret