If I don't want to unpack batch normalization


#1

I have a question regarding batch normalization in NNVM.
By default (with opt_level=0), it seems that NNVM unpacks batch normalization into several unit operators such as add_scalar, sqrt, rdiv_scalar, and elemwise_mul.
If I want to keep batch normalization as a single operator and fuse operators into something like @tvm_op(…, …, , func_name="fuse_conv2d_batch_norm_relu_…", …), what steps do I need to take?
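
For reference, inference-mode batch normalization computes y = gamma * (x - mean) / sqrt(var + eps) + beta per channel. The plain-Python sketch below (no NNVM or TVM dependency) checks that rebuilding it from small elementwise steps like the ones listed above gives the same result. The function names are hypothetical, and the exact operator sequence NNVM emits may differ from what the comments suggest.

```python
import math

def batch_norm_direct(x, gamma, beta, mean, var, eps=1e-5):
    """Reference inference-mode batch norm; x is a list of channels, each a list of values."""
    return [
        [g * (v - m) / math.sqrt(s + eps) + b for v in ch]
        for ch, g, b, m, s in zip(x, gamma, beta, mean, var)
    ]

def batch_norm_unpacked(x, gamma, beta, mean, var, eps=1e-5):
    """Same result built from small steps (op names in comments are illustrative only)."""
    out = []
    for ch, g, b, m, s in zip(x, gamma, beta, mean, var):
        denom = math.sqrt(s + eps)   # add_scalar (var + eps), then sqrt
        inv_std = 1.0 / denom        # rdiv_scalar: 1 / denom
        scale = inv_std * g          # elemwise_mul with gamma
        shift = b - m * scale        # per-channel shift, independent of x
        out.append([v * scale + shift for v in ch])  # one multiply and one add per element
    return out

x = [[1.0, 2.0, 3.0], [-1.0, 0.0, 4.0]]
gamma, beta = [0.5, 2.0], [0.1, -0.3]
mean, var = [2.0, 1.0], [0.25, 4.0]

ref = batch_norm_direct(x, gamma, beta, mean, var)
unp = batch_norm_unpacked(x, gamma, beta, mean, var)
```

Because scale and shift depend only on the learned parameters, the per-element work reduces to one multiply and one add, which is what makes the unpacked form easy to fuse into a preceding convolution.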

So far, I have noticed the following:

  1. The computation of batch normalization is implemented in both Python (./tvm/topi/python/topi/nn/batch_norm.py) and C++ (./tvm/topi/include/topi/nn/batch_norm.h).
  2. The computation and schedule are not registered to NNVM (see ./python/nnvm/top/nn.py), whereas they are registered to TVM (see ./tvm/topi/src/topi.cc).
  3. The operator pattern for batch normalization is registered as BROADCAST, as seen in the tag scope of batch_norm_inference() in batch_norm.py.

I tried to register the batch normalization op to NNVM and execute the code, but it eventually fails in GraphFuseCompile(), specifically when fcompute (essentially, the lambda function of the batch normalization code in batch_norm.py) is called in compute() in ./tvm/python/tvm/api.py. Since fcompute is a lambda function, I'm not sure how to debug or trace the code efficiently.
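
One way to make a failing fcompute easier to trace is to replace the lambda with a named function and wrap it so that an exception reports the function name and the index arguments it received. This is a generic Python sketch: trace_fcompute and bn_elem are hypothetical names, and bn_elem is called here with plain integers, whereas tvm.compute calls fcompute with symbolic index variables. Whether tvm.compute treats a *args wrapper the same way as the original function depends on how it inspects fcompute's parameters, so treat that part as an assumption.

```python
import functools

def trace_fcompute(fn):
    """Wrap an fcompute-style callable so failures report its name and arguments.

    Hypothetical debugging helper. It forwards *indices unchanged; a framework
    that introspects fcompute's parameter list may see the wrapper differently.
    """
    @functools.wraps(fn)
    def wrapper(*indices):
        try:
            return fn(*indices)
        except Exception as exc:
            raise RuntimeError(
                f"fcompute {fn.__name__!r} failed for indices {indices!r}: {exc}"
            ) from exc
    return wrapper

# A named function in place of a lambda: tracebacks show "bn_elem", not "<lambda>".
def bn_elem(c, i):
    data = [[1.0, 2.0], [3.0, 4.0]]  # toy per-channel data (hypothetical)
    scale = [0.5, 2.0]               # toy per-channel scale (hypothetical)
    return data[c][i] * scale[c]

fcompute = trace_fcompute(bn_elem)
value = fcompute(1, 0)
```

Setting a breakpoint inside bn_elem (for example with pdb) then stops in a frame with a readable name instead of <lambda>.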

Do you have any comments or suggestions? Thanks!

Won


#2

Unfortunately, this is hard to do. You could mark batchnorm as opaque and provide a manual schedule for it.

But note that TVM automatically fuses these operators together again, so unpacking batchnorm actually makes follow-up optimizations easier.
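
The trade-off described above can be illustrated with a deliberately simplified fusion model. The pattern names mirror NNVM's OpPattern and the numeric values are assumed to follow its ordering, but fuse_chain is a hypothetical greedy rule over a linear chain of ops, not NNVM's actual fusion passes. It shows that the unpacked multiply/add ops fuse into the convolution, that registering batch_norm as BROADCAST (as in the question) also yields a single conv2d_batch_norm_relu group, and that marking it opaque splits the chain into separate kernels.

```python
from enum import IntEnum

class Pattern(IntEnum):
    # Names mirror NNVM's OpPattern; values assumed to follow its ordering
    # (lower values are easier to fuse).
    ELEMWISE = 0
    BROADCAST = 1
    INJECTIVE = 2
    COMM_REDUCE = 3
    OUT_ELEMWISE_FUSABLE = 4
    OPAQUE = 8

def fuse_chain(ops):
    """Greedily group a linear chain of (name, pattern) ops into fused kernels.

    Simplified model: an ELEMWISE/BROADCAST op joins the current group unless
    that group is led by an OPAQUE op; every other op starts a new group.
    """
    groups = []  # each entry: (list of op names, pattern of the group's first op)
    for name, pat in ops:
        fusable = pat <= Pattern.BROADCAST
        if groups and fusable and groups[-1][1] != Pattern.OPAQUE:
            groups[-1][0].append(name)
        else:
            groups.append(([name], pat))
    return ["_".join(names) for names, _ in groups]

conv = ("conv2d", Pattern.OUT_ELEMWISE_FUSABLE)
relu = ("relu", Pattern.ELEMWISE)

unpacked = [conv, ("broadcast_mul", Pattern.BROADCAST), ("broadcast_add", Pattern.BROADCAST), relu]
broadcast_bn = [conv, ("batch_norm", Pattern.BROADCAST), relu]
opaque_bn = [conv, ("batch_norm", Pattern.OPAQUE), relu]

for chain in (unpacked, broadcast_bn, opaque_bn):
    print(fuse_chain(chain))
```

In this model the opaque batch_norm stays intact as one op, at the cost of running conv2d and relu as separate kernels, which is why unpacking is usually the easier path unless a hand-written batch norm schedule pays for that loss.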


#3

Can I ask why you want a single batchnorm operator? Do you have a special implementation for it?


#4

I imagine a use case where a particular hardware platform provides a fully optimized batch norm intrinsic or instruction.


#5

Hmm… that would be interesting. Well… if that is the case, maybe we could have some arguments or environment variables that tell NNVM not to unpack a specific operator (for now, only batchnorm and dropout).
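
The argument/environment-variable idea could look roughly like the sketch below. NNVM_KEEP_OPS and ops_to_unpack are invented names for illustration only; an explicit keep argument takes precedence over the environment, and the result is the set of operators a simplification pass would still unpack.

```python
import os

# Operators the thread says are currently unpacked (batchnorm and dropout).
DEFAULT_UNPACKED = {"batch_norm", "dropout"}

def ops_to_unpack(keep=None, env=None, var_name="NNVM_KEEP_OPS"):
    """Return the ops a simplification pass should still unpack.

    Hypothetical: var_name is an invented environment variable holding a
    comma-separated list of ops to keep intact. An explicit `keep` argument
    overrides the environment.
    """
    if keep is None:
        env = os.environ if env is None else env
        raw = env.get(var_name, "")
        keep = {name.strip() for name in raw.split(",") if name.strip()}
    return DEFAULT_UNPACKED - set(keep)

# Keep batch_norm intact (e.g. for a hardware intrinsic) but still unpack dropout.
remaining = ops_to_unpack(env={"NNVM_KEEP_OPS": "batch_norm"})
```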