tensorflow添加新操作(Op)

参考:https://tensorflow.juejin.im/extend/adding_an_op.html

https://zhuanlan.zhihu.com/p/34168765

为了加入一个定制操作,你需要:

  1. 在 C++ 文件中注册这个新操作。操作的注册为此操作的功能定义了一个接口(规范)。比如,操作的注册定义了此操作的名称和它的输入输出。它还定义了 shape 函数,用于获取张量的形状。
  2. 在 C++ 中实现这个操作。操作的实现称为内核,它是你在步骤 1 中注册的规范的具体实现。对于不同的输入输出类型或架构(比如不同的 CPUs 或 GPUs),可能有多个内核。
  3. 创建一个 Python 包装器(可选)。这个包装器是用于在 Python 中创建操作的公共 API。操作的注册可以产生一个默认的包装器,它可以直接使用,或添加。
  4. 为操作编写一个函数来计算梯度(可选)。
  5. tf.test.compute_gradient_error

1. 定义接口:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("ZeroOut")
    .Input("to_zero: int32")
    .Output("zeroed: int32")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });

关于命名的备注:操作名称必须首字母大写,而且不能和库中已经注册的其它操作重名。

2. 实现操作的内核

定义接口后,接下来就需要为此操作提供一个或多个内核实现了。
为了实现这些内核,创建一个继承自 OpKernel 的类,并重载 Compute 方法。
Compute 方法有一个类型为 OpKernelContext* 的参数 context,从中可以访问输入和输出张量等有用的信息。

将你的内核加到上面创建的文件中。这个内核的代码形如:

#include "tensorflow/core/framework/op_kernel.h"

using namespace tensorflow;

class ZeroOutOp : public OpKernel {
 public:
  explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    // 得到输入张量
    const Tensor& input_tensor = context->input(0);
    auto input = input_tensor.flat<int32>();

    // 创建输出张量
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    auto output_flat = output_tensor->flat<int32>();

    // 除第一个元素外,输出张量的其它所有元素都设置为 0 
    const int N = input.size();
    for (int i = 1; i < N; i++) {
      output_flat(i) = 0;
    }

    // 如果可能的话,保留第一个输入值
    if (N > 0) output_flat(0) = input(0);
  }
};

ZeroOut 操作加上约束条件:

REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);

这里注册的操作名是ZeroOut,通过上面的语句和ZeroOutOp对应吧,

输入和输出

下面对前面的示例做个总结,一个操作注册可以指定多个输入输出:

REGISTER_OP("MultipleInsAndOuts")
    .Input("y: int32")
    .Input("z: float")
    .Output("a: string")
    .Output("b: int32");
 

相关推荐