Optimizing memory loads


#1

Hi, I’m trying to reuse input values across loop iterations because the loads are expensive in my use case. Here’s a simple example:

	import tvm

	in_size = 10
	filter_size = 3
	out_size = in_size - filter_size + 1

	# 1-D sliding-window sum: Out[x] = sum_{ra} Input[x + ra]
	A = tvm.placeholder((in_size,), name='Input')
	ra = tvm.reduce_axis((0, filter_size), name='ra')
	Out = tvm.compute((out_size,), lambda x: tvm.sum(A[x + ra], axis=[ra]), name='Out')

	s = tvm.create_schedule(Out.op)
	AL = s.cache_read(A, "local", [Out])      # stage the input reads in a local buffer
	s[AL].compute_at(s[Out], Out.op.axis[0])  # fill the buffer inside the x loop

	print(tvm.lower(s, [A, Out], simple_mode=True))

Produces:

	// attr [Input.local] storage_scope = "local"
	allocate Input.local[float32 * 3]
	produce Out {
	  for (x, 0, 8) {
		produce Input.local {
		  for (ax0, 0, 3) {
			Input.local[ax0] = Input[(x + ax0)] /* here there are 3 loads for each x iteration */
		  }
		}
		Out[x] = 0.000000f
		for (ra, 0, 3) {
		  Out[x] = (Out[x] + Input.local[ra])
		}
	  }
	}

But I need something like:

	// attr [Input.local] storage_scope = "local"
	allocate Input.local[float32 * 3]
	produce Out {
	  Input.local[0] = Input[0]
	  Input.local[1] = Input[1]
	  for (x, 0, 8) {
		Input.local[2] = Input[(x + 2)] /* only one load inside the x loop */
		Out[x] = 0.000000f
		for (ra, 0, 3) {
		  Out[x] = (Out[x] + Input.local[ra])
		}
		Input.local[0] = Input.local[1]
		Input.local[1] = Input.local[2] /* some way to permute the loaded local values or their pointers*/
	  }
	}

I tried double buffering, but s[AL].double_buffer() grows the allocation to Input.local[float32 * 2 * 3], which is not what I want. Staging at the reduction axis instead, s[AL].compute_at(s[Out], Out.op.reduce_axis[0]) followed by s[AL].double_buffer(), turns it into Input.local[float32 * 2 * 1], which does not help either. The two attempts are sketched below.
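For reference, the two attempts look roughly like this, reusing the compute definition from the first snippet (the s1/s2/AL1/AL2 names are mine, and each attempt is a fresh schedule):

	# Attempt 1: double-buffer the whole 3-element window
	s1 = tvm.create_schedule(Out.op)
	AL1 = s1.cache_read(A, "local", [Out])
	s1[AL1].compute_at(s1[Out], Out.op.axis[0])
	s1[AL1].double_buffer()        # lowers to Input.local[float32 * 2 * 3]

	# Attempt 2: stage at the reduction axis, then double-buffer
	s2 = tvm.create_schedule(Out.op)
	AL2 = s2.cache_read(A, "local", [Out])
	s2[AL2].compute_at(s2[Out], Out.op.reduce_axis[0])
	s2[AL2].double_buffer()        # lowers to Input.local[float32 * 2 * 1]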

Is there a way to generate the above schedule using the current TVM functions?


#2

Are you targeting TVM to HLS? What you describe is essentially a shift register. Implementing a shift register is expensive at the software level, but it is a normal pattern in hardware design (see the sketch below). If you want to support this, we can discuss hardware-specialized optimization in detail.
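To make the cost concrete, here is a plain-Python sketch of the shift-register pattern from the desired IR above (the inp/out/window names and the stand-in data are only illustrative, not from the original post). In hardware the shift is just wiring between registers; in software it costs filter_size - 1 extra copies per output element:

	# Shift-register sketch: one new load plus two copies per output element.
	in_size, filter_size = 10, 3
	out_size = in_size - filter_size + 1
	inp = [float(i) for i in range(in_size)]         # stand-in input data
	out = [0.0] * out_size

	window = [inp[0], inp[1], 0.0]                   # preload the first two taps
	for x in range(out_size):
	    window[2] = inp[x + 2]                       # the single new load
	    out[x] = window[0] + window[1] + window[2]   # reduction over the window
	    window[0] = window[1]                        # the "shift": free in hardware,
	    window[1] = window[2]                        # extra copies in software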


#3

Here is an example that uses Input.local as a circular buffer. In this simplified example each Input.local[x] holds a single float, but in a real use case it might hold a whole block of data, which is why avoiding the repeated memory transfers matters.

	// attr [Input.local] storage_scope = "local"
	allocate Input.local[float32 * 3]
	produce Out {
	  Input.local[1] = Input[0]
	  Input.local[2] = Input[1]
	  for (x, 0, 8) {
	    Input.local[x % 3] = Input[(x + 2)] /* only one load inside the x loop */
	    Out[x] = 0.000000f
	    for (ra, 0, 3) {
	      Out[x] = (Out[x] + Input.local[ra])
	    }
	  }
	}
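
As a sanity check, a plain-Python/NumPy equivalent of the hand-written IR above (the inp/local/out names are only illustrative) confirms that the modulo indexing reproduces the ordinary sliding-window sum:

	import numpy as np

	in_size, filter_size = 10, 3
	out_size = in_size - filter_size + 1
	inp = np.arange(in_size, dtype="float32")        # stand-in input data

	local = np.empty(filter_size, dtype="float32")   # the 3-element circular buffer
	local[1], local[2] = inp[0], inp[1]              # preload, as in the IR prologue
	out = np.zeros(out_size, dtype="float32")
	for x in range(out_size):
	    local[x % 3] = inp[x + 2]                    # single load per iteration
	    out[x] = local.sum()                         # order does not matter for a sum

	ref = [inp[x:x + filter_size].sum() for x in range(out_size)]
	assert np.allclose(out, ref)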

It seems like Halide automatically detects this case and injects the circular-buffer optimization (http://halide-lang.org/tutorials/tutorial_lesson_08_scheduling_2.html); the pass is implemented in Halide/src/StorageFolding.cpp.