I’m trying to fix the current Reshape problem in the keras frontend, but have not succeeded yet.
This is the test case that still fails:
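```python
import keras

def test_forward_reshape():
    data = keras.layers.Input(shape=(32, 32, 3))
    # target_shape excludes the batch dimension and is channels-last (HWC)
    x = keras.layers.Reshape(target_shape=(32, 32, 3))(data)
    x = keras.layers.GlobalAveragePooling2D()(x)
    keras_model = keras.models.Model(data, x)
    # verify_keras_frontend is the helper from the existing keras frontend tests
    verify_keras_frontend(keras_model)
```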
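The problem is that the target_shape parameter of Reshape is in HWC (channels-last) format, while the input tensor in the frontend is in (N)CHW format. As a result, target_shape cannot be applied to the converted tensor as-is: reshaping an NCHW tensor of shape (1, 3, 32, 32) to (1, 32, 32, 3) gives the right shape but scrambles the element order.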
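For reference, here is a minimal NumPy sketch of one workaround I am considering, assuming the frontend keeps 4-D activations in NCHW: transpose back to NHWC, apply target_shape exactly as Keras would (with the batch dimension prepended), and transpose to NCHW again only if the result is 4-D. The function name reshape_nchw_with_nhwc_target is hypothetical, and NumPy is used only to check the index bookkeeping; in the frontend itself the same steps would have to be built from the framework's own transpose and reshape operators.

```python
import numpy as np


def reshape_nchw_with_nhwc_target(x_nchw, target_shape_hwc):
    """Hypothetical helper: apply a Keras-style (channels-last) Reshape to NCHW data.

    Assumption: 4-D activations are stored as NCHW, while Keras' target_shape
    excludes the batch dimension and describes channels-last data.
    """
    n = x_nchw.shape[0]
    # 1. Go back to the layout Keras assumed: NCHW -> NHWC.
    x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))
    # 2. Reshape exactly as Keras would, with the batch dimension prepended.
    y = np.reshape(x_nhwc, (n,) + tuple(target_shape_hwc))
    # 3. If the result is again a (N, H, W, C) feature map, convert it back to
    #    NCHW so that later layers (e.g. GlobalAveragePooling2D) still see NCHW.
    if y.ndim == 4:
        y = np.transpose(y, (0, 3, 1, 2))
    return y


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x_nhwc = rng.standard_normal((1, 32, 32, 3)).astype("float32")
    x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))

    # Naive approach: reshaping the NCHW buffer directly reorders the data.
    naive = np.reshape(x_nchw, (1, 32, 32, 3))
    assert not np.array_equal(naive, x_nhwc)

    # Sketch: the identity Reshape from the test case leaves the tensor unchanged.
    assert np.array_equal(reshape_nchw_with_nhwc_target(x_nchw, (32, 32, 3)), x_nchw)

    # A non-trivial 3-D target matches Keras' NHWC result (converted to NCHW).
    ref = np.reshape(x_nhwc, (1, 16, 64, 3))
    out = reshape_nchw_with_nhwc_target(x_nchw, (16, 64, 3))
    assert np.array_equal(out, np.transpose(ref, (0, 3, 1, 2)))
```

The cost is two extra layout changes per Reshape. When target_shape equals the input's (H, W, C) shape, as in the failing test, the whole sequence is an identity and could be skipped.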
Is there a good way to solve this problem?
My ongoing work is here: