深入理解TVM:Python/C++互调

深入理解TVM:Python/C++互调(上)

一、概述

TVM已经是一个很庞大的系统,包含了很多的功能模块,其中python和c++的互相调用这个功能模块,没有使用第三方的开源库(boost.python、pybind11等),而是自己实现了一套复杂但精致高效强大的机制,值得好好研究学习。这部分内容很多,一篇文章很难说清楚,我准备把这部分分成上、中、下三篇来说,尽可能的把实现原理讲清楚:

  1. 上篇:最底层的c++数据结构支撑(围绕c++端PackedFunc)
  2. 中篇:基于PackedFunc的函数注册(围绕TVM_REGISTER_GLOBAL)
  3. 下篇:偏上层的python的调用细节(围绕ctypes内置库和python端PackedFunc)

本文讲第一部分,也就是围绕PackedFunc这个类来说,它是python和c++互调的桥梁,此类实现代码在include/tvm/runtime/packed_func.h文件中,这里面还有一个TypedPackedFunc类,它只是PackedFunc的一个wrapper,主要增加了类型检查的功能,开发TVM的c++代码要尽可能的使用这个类,但是我们为了把问题尽可能的简化,只关注PackedFunc这个最底层类,其中用到了下面这几个关键的数据结构:

  • TVMValue
  • TVMArgs
  • TVMPODValue_
  • TVMArgValue
  • TVMRetValue
  • TVMArgsSetter

下面结合代码,逐个来说(注:本文基于fe25b9e7c这个commit,下面所有列出的代码都做了相当大量的精简和修改,一来只为讲清楚原理,二来限于篇幅,大家如果有兴趣了解更多的细节,需要再去github上看实际的实现)

二、TVMValue

这是最基本的一个数据结构,是一个union,实现在include/tvm/runtime/c_runtime_api.h,主要是为了储存c++和其它语言交互时所支持的几种类型的数据,代码很简单(其中DLDataType和DLDevice是两个复合数据类型,限于篇幅,这里不能全部列出来,大家需要自己到github追下细节):

typedef union {
  int64_t v_int64;
  double v_float64;
  void* v_handle;
  const char* v_str;
  DLDataType v_type;
  DLDevice v_device;
} TVMValue;

三、TVMArgs

这个类主要是为了封装传给PackedFunc的所有参数,这个类也比较简单原始,主要基于TVMValue、参数类型编码、参数个数来实现,代码如下:

class TVMArgs {
 public:
  const TVMValue* values;
  const int* type_codes;
  int num_args;
  TVMArgs(const TVMValue* values, 
          const int* type_codes, 
          int num_args) { ... }

  inline int size() const { return num_args; }
  inline TVMArgValue operator[](int i) const { 
      return TVMArgValue(values[i], type_codes[i]); 
  }
};

四、TVMPODValue_

这是一个内部使用的基类,主要主要服务于后面介绍到的TVMArgValue和TVMRetValue,从名字可以看出,这个类主要是处理POD类型的数据,POD是plain old data的缩写,要么是scalar type,要么是trival type,要么是standard layout type,具体可参考cppreference的PODTypeis_podis_scalaris_trivialis_standard_layout等章节。其实关于POD类型,可以单独写一大篇文章,但它不是本文的重点,以后有时间再专门写文章细说。这个类的实现核心是强制类型转换运算符重载(在c++中,类型的名字,包括类的名字本身也是一种运算符,即类型强制转换运算符),如下面代码所示:

class TVMPODValue_ {
 public:
  operator double() const { return value_.v_float64; }
  operator int64_t() const { return value_.v_int64; }
  operator void*() const { return value_.v_handle; }
  template <typename T>
  T* ptr() const { return static_cast<T*>(value_.v_handle); }
​
 protected:
  TVMValue value_;
  int type_code_;
};

五、TVMArgValue

这个类继承自前面的TVMPODValue_类,用作表示PackedFunc的一个参数,它和TVMPODValue_的区别是扩充了一些数据类型的支持,比如string、PackedFunc、TypedPackedFunc等,其中对后两个的支持是在c++代码中能够调用python函数的根本原因。这个类只使用所保存的underlying data,而不会去做释放,代码如下:

class TVMArgValue : public TVMPODValue_ {
 public:
  TVMArgValue() {}
  TVMArgValue(TVMValue value, int type_code) 
  : TVMPODValue_(value, type_code) {}
​
  operator std::string() const {}
  operator PackedFunc() const { return *ptr<PackedFunc>(); }
  const TVMValue& value() const { return value_; }

  template <typename T>
  inline operator T() const;
  inline operator DLDataType() const;
  inline operator DataType() const;
};

六、TVMRetValue

这个类也是继承自TVMPODValue_类,主要作用是作为存放调用PackedFunc返回值的容器,它和TVMArgValue的区别是,它会管理所保存的underlying data,会对其做释放。这个类主要由四部分构成:

  • 构造和析构函数
  • 对强制类型转换运算符重载的扩展
  • 对赋值运算符的重载
  • 辅助函数,包括释放资源的Clear函数

代码如下:

class TVMRetValue : public TVMPODValue_ {
 public:
  // ctor and dtor, dtor will release related buffer
  TVMRetValue() {}
  ~TVMRetValue() { this->Clear(); }
​
  // conversion operators
  operator std::string() const { ret
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值