Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "use tensor.shape bug not paddle.shape(tensor)" #8922

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion ppdet/modeling/architectures/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _forward(self):
else:
bbox, bbox_num, mask = self.post_process(
preds, self.inputs['im_shape'], self.inputs['scale_factor'],
self.inputs['image'])[2:].shape
paddle.shape(self.inputs['image'])[2:])

output = {'bbox': bbox, 'bbox_num': bbox_num}
if self.with_mask:
Expand Down
2 changes: 1 addition & 1 deletion ppdet/modeling/architectures/pose3d_metro.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def orthographic_projection(X, camera):
"""
camera = camera.reshape((-1, 1, 3))
X_trans = X[:, :, :2] + camera[:, :, 1:]
shape = X_trans.shape
shape = paddle.shape(X_trans)
X_2d = (camera[:, :, 0] * X_trans.reshape((shape[0], -1))).reshape(shape)
return X_2d

Expand Down
10 changes: 5 additions & 5 deletions ppdet/modeling/backbones/trans_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def __init__(self, word_size, position_embeddings_size, word_type_size,
self.dropout = nn.Dropout(dropout_prob)

def forward(self, x, token_type_ids=None, position_ids=None):
seq_len = x.shape[1]
seq_len = paddle.shape(x)[1]
if position_ids is None:
position_ids = paddle.arange(seq_len).unsqueeze(0).expand_as(x)
if token_type_ids is None:
token_type_ids = paddle.zeros(x.shape)
token_type_ids = paddle.zeros(paddle.shape(x))

word_embs = self.word_embeddings(x)
position_embs = self.position_embeddings(position_ids)
Expand Down Expand Up @@ -82,7 +82,7 @@ def forward(self, x, attention_mask, head_mask=None):
key = self.key(x)
value = self.value(x)

query_dim1, query_dim2 = query.shape[:-1]
query_dim1, query_dim2 = paddle.shape(query)[:-1]
new_shape = [
query_dim1, query_dim2, self.num_attention_heads,
self.attention_head_size
Expand All @@ -102,7 +102,7 @@ def forward(self, x, attention_mask, head_mask=None):

context = paddle.matmul(attention_value, value).transpose(perm=(0, 2, 1,
3))
ctx_dim1, ctx_dim2 = context.shape[:-2]
ctx_dim1, ctx_dim2 = paddle.shape(context)[:-2]
new_context_shape = [
ctx_dim1,
ctx_dim2,
Expand Down Expand Up @@ -303,7 +303,7 @@ def init_weights(self, module):
module.bias.set_value(paddle.zeros(shape=module.bias.shape))

def forward(self, x):
batchsize, seq_len = x.shape[:2]
batchsize, seq_len = paddle.shape(x)[:2]
input_ids = paddle.zeros((batchsize, seq_len), dtype="int64")
position_ids = paddle.arange(
seq_len, dtype="int64").unsqueeze(0).expand_as(input_ids)
Expand Down
6 changes: 3 additions & 3 deletions ppdet/modeling/backbones/transformer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def drop_path(x, drop_prob=0., training=False):
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1 - drop_prob, dtype=x.dtype)
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
Expand Down Expand Up @@ -85,7 +85,7 @@ def window_partition(x, window_size):
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = x.shape
B, H, W, C = paddle.shape(x)

pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
Expand Down Expand Up @@ -116,7 +116,7 @@ def window_unpartition(x, pad_hw, num_hw, hw):
Hp, Wp = pad_hw
num_h, num_w = num_hw
H, W = hw
B, window_size, _, C = x.shape
B, window_size, _, C = paddle.shape(x)
B = B // (num_h * num_w)
x = x.reshape([B, num_h, num_w, window_size, window_size, C])
x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, C])
Expand Down
2 changes: 1 addition & 1 deletion ppdet/modeling/backbones/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __init__(self,
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x, rel_pos_bias=None):
x_shape = x.shape
x_shape = paddle.shape(x)
N, C = x_shape[1], x_shape[2]

qkv_bias = None
Expand Down
4 changes: 2 additions & 2 deletions ppdet/modeling/backbones/vit_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def add_decomposed_rel_pos(self, attn, q, h, w):
return attn.reshape([B, h * w, h * w])

def forward(self, x):
B, H, W, C = x.shape
B, H, W, C = paddle.shape(x)

if self.q_bias is not None:
qkv_bias = paddle.concat(
Expand Down Expand Up @@ -567,7 +567,7 @@ def get_2d_sincos_position_embedding(self, h, w, temperature=10000.):

def forward(self, inputs):
x = self.patch_embed(inputs['image']).transpose([0, 2, 3, 1])
B, Hp, Wp, _ = x.shape
B, Hp, Wp, _ = paddle.shape(x)

if self.use_abs_pos:
x = x + self.get_2d_sincos_position_embedding(Hp, Wp)
Expand Down
4 changes: 2 additions & 2 deletions ppdet/modeling/backbones/vitpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def drop_path(x, drop_prob=0., training=False):
if drop_prob == 0. or not training:
return x
keep_prob = paddle.to_tensor(1.0 - drop_prob).astype(x.dtype)
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
shape = (paddle.shape(x)[0], ) + (1, ) * (x.ndim - 1)
random_tensor = keep_prob + paddle.rand(shape).astype(x.dtype)
random_tensor = paddle.floor(random_tensor) # binarize
output = x.divide(keep_prob) * random_tensor
Expand Down Expand Up @@ -303,7 +303,7 @@ def _init_weights(self):

def forward_features(self, x):

B = x.shape[0]
B = paddle.shape(x)[0]
x = self.patch_embed(x)
B, D, Hp, Wp = x.shape
x = x.flatten(2).transpose([0, 2, 1])
Expand Down
6 changes: 3 additions & 3 deletions ppdet/modeling/heads/fcosr_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def _generate_anchors(self, feats):
if self.trt:
anchor_points = []
for feat, stride in zip(feats, self.fpn_strides):
_, _, h, w = feat.shape
_, _, h, w = paddle.shape(feat)
anchor, _ = anchor_generator(
feat,
stride * 4,
Expand All @@ -206,7 +206,7 @@ def _generate_anchors(self, feats):
stride_tensor = []
num_anchors_list = []
for feat, stride in zip(feats, self.fpn_strides):
_, _, h, w = feat.shape
_, _, h, w = paddle.shape(feat)
shift_x = (paddle.arange(end=w) + 0.5) * stride
shift_y = (paddle.arange(end=h) + 0.5) * stride
shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
Expand Down Expand Up @@ -263,7 +263,7 @@ def forward_eval(self, feats, target=None):
cls_pred_list, reg_pred_list = [], []
anchor_points, _, _ = self._generate_anchors(feats)
for stride, feat, scale in zip(self.fpn_strides, feats, self.scales):
b, _, h, w = feat.shape
b, _, h, w = paddle.shape(feat)
# cls
cls_feat = feat
for cls_layer in self.stem_cls:
Expand Down
4 changes: 2 additions & 2 deletions ppdet/modeling/heads/gfl_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def forward(self, fpn_feats):
if not self.training:
cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
b, cell_h, cell_w, _ = cls_score.shape
b, cell_h, cell_w, _ = paddle.shape(cls_score)
y, x = self.get_single_level_center_point(
[cell_h, cell_w], stride, cell_offset=self.cell_offset)
center_points = paddle.stack([x, y], axis=-1)
Expand Down Expand Up @@ -515,7 +515,7 @@ def forward(self, fpn_feats):
if not self.training:
cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
b, cell_h, cell_w, _ = cls_score.shape
b, cell_h, cell_w, _ = paddle.shape(cls_score)
y, x = self.get_single_level_center_point(
[cell_h, cell_w], stride, cell_offset=self.cell_offset)
center_points = paddle.stack([x, y], axis=-1)
Expand Down
9 changes: 5 additions & 4 deletions ppdet/modeling/heads/mask_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,13 @@ def forward_test(self,
if self.num_classes == 1:
mask_out = F.sigmoid(mask_logit)[:, 0, :, :]
else:
num_masks = mask_logit.shape[0]
num_masks = paddle.shape(mask_logit)[0]
index = paddle.arange(num_masks).cast('int32')
mask_out = mask_logit[index, labels]
mask_out_shape = mask_out.shape
mask_out = paddle.reshape(mask_out,
index.shape + [mask_out_shape[-2]] + [mask_out_shape[-1]])
mask_out_shape = paddle.shape(mask_out)
mask_out = paddle.reshape(mask_out, [
paddle.shape(index), mask_out_shape[-2], mask_out_shape[-1]
])
mask_out = F.sigmoid(mask_out)
return mask_out

Expand Down
2 changes: 1 addition & 1 deletion ppdet/modeling/heads/pico_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def forward_train(self, fpn_feats):

cls_score_out = cls_score.transpose([0, 2, 3, 1])
bbox_pred = reg_pred.transpose([0, 2, 3, 1])
b, cell_h, cell_w, _ = cls_score_out.shape
b, cell_h, cell_w, _ = paddle.shape(cls_score_out)
y, x = self.get_single_level_center_point(
[cell_h, cell_w], stride, cell_offset=self.cell_offset)
center_points = paddle.stack([x, y], axis=-1)
Expand Down
6 changes: 3 additions & 3 deletions ppdet/modeling/heads/ppyoloe_r_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def _generate_anchors(self, feats):
if self.trt:
anchor_points = []
for feat, stride in zip(feats, self.fpn_strides):
_, _, h, w = feat.shape
_, _, h, w = paddle.shape(feat)
anchor, _ = anchor_generator(
feat,
stride * 4,
Expand All @@ -156,7 +156,7 @@ def _generate_anchors(self, feats):
stride_tensor = []
num_anchors_list = []
for feat, stride in zip(feats, self.fpn_strides):
_, _, h, w = feat.shape
_, _, h, w = paddle.shape(feat)
shift_x = (paddle.arange(end=w) + 0.5) * stride
shift_y = (paddle.arange(end=h) + 0.5) * stride
shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
Expand Down Expand Up @@ -210,7 +210,7 @@ def forward_eval(self, feats):
cls_score_list, reg_box_list = [], []
anchor_points, _, _ = self._generate_anchors(feats)
for i, (feat, stride) in enumerate(zip(feats, self.fpn_strides)):
b, _, h, w = feat.shape
b, _, h, w = paddle.shape(feat)
l = h * w
# cls
avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
Expand Down
6 changes: 3 additions & 3 deletions ppdet/modeling/heads/s2anet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def forward(self, feats, targets=None):
for i, feat in enumerate(feats):
# get shape
B = feat.shape[0]
H, W = feat.shape[2], feat.shape[3]
H, W = paddle.shape(feat)[2], paddle.shape(feat)[3]

NA = H * W
num_anchors_list.append(NA)
Expand Down Expand Up @@ -324,7 +324,7 @@ def forward(self, feats, targets=None):

def get_bboxes(self, head_outs):
perd_bboxes_list, pred_scores_list = head_outs
batch = pred_scores_list[0].shape[0]
batch = paddle.shape(pred_scores_list[0])[0]
bboxes, bbox_num = [], []
for i in range(batch):
pred_scores_per_image = [t[i] for t in pred_scores_list]
Expand Down Expand Up @@ -712,7 +712,7 @@ def rbox2poly(self, rboxes):
to
polys: [x0,y0,x1,y1,x2,y2,x3,y3]
"""
N = rboxes.shape[0]
N = paddle.shape(rboxes)[0]

x_ctr = rboxes[:, 0]
y_ctr = rboxes[:, 1]
Expand Down
44 changes: 22 additions & 22 deletions ppdet/modeling/heads/solov2_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,16 @@ def forward(self, inputs):
if i == (self.range_level - 1):
input_feat = input_p
x_range = paddle.linspace(
-1, 1, input_feat.shape[-1], dtype='float32')
-1, 1, paddle.shape(input_feat)[-1], dtype='float32')
y_range = paddle.linspace(
-1, 1, input_feat.shape[-2], dtype='float32')
-1, 1, paddle.shape(input_feat)[-2], dtype='float32')
y, x = paddle.meshgrid([y_range, x_range])
x = paddle.unsqueeze(x, [0, 1])
y = paddle.unsqueeze(y, [0, 1])
y = paddle.expand(
y, shape=[input_feat.shape[0], 1, -1, -1])
y, shape=[paddle.shape(input_feat)[0], 1, -1, -1])
x = paddle.expand(
x, shape=[input_feat.shape[0], 1, -1, -1])
x, shape=[paddle.shape(input_feat)[0], 1, -1, -1])
coord_feat = paddle.concat([x, y], axis=1)
input_p = paddle.concat([input_p, coord_feat], axis=1)
feat_all_level = paddle.add(feat_all_level,
Expand Down Expand Up @@ -271,7 +271,7 @@ def _split_feats(self, feats):
align_mode=0,
mode='bilinear'), feats[1], feats[2], feats[3], F.interpolate(
feats[4],
size=feats[3].shape[-2:],
size=paddle.shape(feats[3])[-2:],
mode='bilinear',
align_corners=False,
align_mode=0))
Expand Down Expand Up @@ -300,16 +300,16 @@ def _get_output_single(self, input, idx):
ins_kernel_feat = input
# CoordConv
x_range = paddle.linspace(
-1, 1, ins_kernel_feat.shape[-1], dtype='float32')
-1, 1, paddle.shape(ins_kernel_feat)[-1], dtype='float32')
y_range = paddle.linspace(
-1, 1, ins_kernel_feat.shape[-2], dtype='float32')
-1, 1, paddle.shape(ins_kernel_feat)[-2], dtype='float32')
y, x = paddle.meshgrid([y_range, x_range])
x = paddle.unsqueeze(x, [0, 1])
y = paddle.unsqueeze(y, [0, 1])
y = paddle.expand(
y, shape=[ins_kernel_feat.shape[0], 1, -1, -1])
y, shape=[paddle.shape(ins_kernel_feat)[0], 1, -1, -1])
x = paddle.expand(
x, shape=[ins_kernel_feat.shape[0], 1, -1, -1])
x, shape=[paddle.shape(ins_kernel_feat)[0], 1, -1, -1])
coord_feat = paddle.concat([x, y], axis=1)
ins_kernel_feat = paddle.concat([ins_kernel_feat, coord_feat], axis=1)

Expand Down Expand Up @@ -358,7 +358,7 @@ def get_loss(self, cate_preds, kernel_preds, ins_pred, ins_labels,
loss_ins (Tensor): The instance loss Tensor of SOLOv2 network.
loss_cate (Tensor): The category loss Tensor of SOLOv2 network.
"""
batch_size = grid_order_list[0].shape[0]
batch_size = paddle.shape(grid_order_list[0])[0]
ins_pred_list = []
for kernel_preds_level, grid_orders_level in zip(kernel_preds,
grid_order_list):
Expand All @@ -368,25 +368,25 @@ def get_loss(self, cate_preds, kernel_preds, ins_pred, ins_labels,
grid_orders_level = paddle.reshape(grid_orders_level, [-1])
reshape_pred = paddle.reshape(
kernel_preds_level,
shape=(kernel_preds_level.shape[0],
kernel_preds_level.shape[1], -1))
shape=(paddle.shape(kernel_preds_level)[0],
paddle.shape(kernel_preds_level)[1], -1))
reshape_pred = paddle.transpose(reshape_pred, [0, 2, 1])
reshape_pred = paddle.reshape(
reshape_pred, shape=(-1, reshape_pred.shape[2]))
reshape_pred, shape=(-1, paddle.shape(reshape_pred)[2]))
gathered_pred = paddle.gather(reshape_pred, index=grid_orders_level)
gathered_pred = paddle.reshape(
gathered_pred,
shape=[batch_size, -1, gathered_pred.shape[1]])
shape=[batch_size, -1, paddle.shape(gathered_pred)[1]])
cur_ins_pred = ins_pred
cur_ins_pred = paddle.reshape(
cur_ins_pred,
shape=(cur_ins_pred.shape[0],
cur_ins_pred.shape[1], -1))
shape=(paddle.shape(cur_ins_pred)[0],
paddle.shape(cur_ins_pred)[1], -1))
ins_pred_conv = paddle.matmul(gathered_pred, cur_ins_pred)
cur_ins_pred = paddle.reshape(
ins_pred_conv,
shape=(-1, ins_pred.shape[-2],
ins_pred.shape[-1]))
shape=(-1, paddle.shape(ins_pred)[-2],
paddle.shape(ins_pred)[-1]))
ins_pred_list.append(cur_ins_pred)

num_ins = paddle.sum(fg_num)
Expand Down Expand Up @@ -423,7 +423,7 @@ def get_prediction(self, cate_preds, kernel_preds, seg_pred, im_shape,
seg_masks (Tensor): The prediction score of each segmentation.
"""
num_levels = len(cate_preds)
featmap_size = seg_pred.shape[-2:]
featmap_size = paddle.shape(seg_pred)[-2:]
seg_masks_list = []
cate_labels_list = []
cate_scores_list = []
Expand All @@ -449,7 +449,7 @@ def get_prediction(self, cate_preds, kernel_preds, seg_pred, im_shape,
seg_masks, cate_labels, cate_scores = self.get_seg_single(
cate_pred_list, seg_pred_list, kernel_pred_list, featmap_size,
im_shape[idx], scale_factor[idx][0])
bbox_num = cate_labels.shape[0:1]
bbox_num = paddle.shape(cate_labels)[0:1]
return seg_masks, cate_labels, cate_scores, bbox_num

def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
Expand All @@ -462,7 +462,7 @@ def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
w = paddle.cast(im_shape[1], 'int32')
upsampled_size_out = [featmap_size[0] * 4, featmap_size[1] * 4]

y = paddle.zeros(shape=cate_preds.shape, dtype='float32')
y = paddle.zeros(shape=paddle.shape(cate_preds), dtype='float32')
inds = paddle.where(cate_preds > self.score_threshold, cate_preds, y)
inds = paddle.nonzero(inds)
cate_preds = paddle.reshape(cate_preds, shape=[-1])
Expand Down Expand Up @@ -507,7 +507,7 @@ def get_seg_single(self, cate_preds, seg_preds, kernel_preds, featmap_size,
seg_masks = paddle.cast(seg_masks, 'float32')
sum_masks = paddle.sum(seg_masks, axis=[1, 2])

y = paddle.zeros(shape=sum_masks.shape, dtype='float32')
y = paddle.zeros(shape=paddle.shape(sum_masks), dtype='float32')
keep = paddle.where(sum_masks > strides, sum_masks, y)
keep = paddle.nonzero(keep)
keep = paddle.squeeze(keep, axis=[1])
Expand Down
2 changes: 1 addition & 1 deletion ppdet/modeling/heads/yolof_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def forward(self, feats, targets=None):
objectness = self.object_pred(conv_reg_feat)
bboxes_reg = self.bbox_pred(conv_reg_feat)

N, C, H, W = cls_logits.shape[:]
N, C, H, W = paddle.shape(cls_logits)[:]
cls_logits = cls_logits.reshape((N, self.na, self.num_classes, H, W))
objectness = objectness.reshape((N, self.na, 1, H, W))
norm_cls_logits = cls_logits + objectness - paddle.log(
Expand Down