RuntimeError: The 'getitem' operation does not support the type [None, Int64]. #596

Open
sevennotmouse opened this issue Sep 15, 2023 · 1 comment
sevennotmouse commented Sep 15, 2023

2023 Ascend AI Innovation Contest, Algorithm Innovation track: migrating and reproducing the VisionLAN model.

Background: We are migrating train_LF_1.py from the PyTorch source code to MindSpore. Following MindSpore's training paradigm, we define the model, dataset, optimizer, and loss function separately, wrap them in a WithLossCell, instantiate the Model class, and train via model.train.


The problem: training fails with the following error about the getitem operation:
RuntimeError: The 'getitem' operation does not support the type [None, Int64]. The supported types of overload function getitem is: [Tuple, Slice], [List, Slice], [Tensor, Ellipsis], [Tuple, Tensor], [List, Number], [Tensor, Slice], [Dictionary, String], [Tensor, Tensor], [String, Number], [Tensor, Tuple], [Tensor, None], [Tuple, Number], [Tensor, Number], [Tensor, List].
Note: debugging environment: a notebook on the Huawei Cloud ModelArts platform; image: mindspore_1.10.0-cann_6.0.1-py_3.7-euler_2.8.3; flavor: Ascend: 1*Ascend910 | ARM: 24 vCPUs, 96 GB.


__getitem__ is a method defined in our custom lmdbDataset class:

class lmdbDataset():
    def __init__(xx):
        xxx
    def __fromwhich__(xx):
        xxx
    def keepratio_resize(xx):
        xxx
    def __len__(self):
        return self.nSamples
    def __getitem__(self, index):
        fromwhich = self.__fromwhich__()
        if self.global_state == 'Train':
            index = random.randint(0, self.maxlen - 1)
        index = index % self.lengths[fromwhich]
        assert index <= len(self), 'index range error'
        index += 1
        with self.envs[fromwhich].begin(write=False) as txn:
            img_key = 'image-%09d' % index
            try:
                imgbuf = txn.get(img_key.encode())
                buf = six.BytesIO()
                buf.write(imgbuf)
                buf.seek(0)
                img = Image.open(buf).convert('RGB')
            except Exception:
                print('Corrupted image for %d' % index)
                return self[index + 1]
            label_key = 'label-%09d' % index
            # label = str(txn.get(label_key.encode()))
            # if python3
            label = str(txn.get(label_key.encode()), 'utf-8')
            label = re.sub('[^0-9a-zA-Z]+', '', label)

            if (len(label) > 25 or len(label) <= 0) and self.global_state == 'Train':
                print(len(label))
                print(label)
                print('sample too long')
                print(self.global_state)
                return self[index + 1]

            img = self.keepratio_resize(img, self.global_state)
            if self.transform:
                img = self.transform(img)
            # generate masked_id, masked_character, remain_string
            label_res, label_sub, label_id = des_orderlabel(label)
            sample = {'image': img, 'label': label, 'label_res': label_res, 'label_sub': label_sub, 'label_id': label_id}
            # return sample
            return (img, label, label_res, label_sub, label_id)  # return a tuple
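Incidentally, since ds.GeneratorDataset converts every column of the tuple returned by __getitem__ into a Tensor, a None slipping into one column of some sample would only surface much later, as a confusing type error. A minimal, hypothetical sanity check (check_sample and the sample tuples below are our illustration, not part of the original code):

```python
import numpy as np

def check_sample(sample):
    """Return the index of the first suspicious column in a sample tuple:
    one that is None, or that NumPy can only hold as a generic object array."""
    for i, col in enumerate(sample):
        if col is None:
            return i
        if np.asarray(col).dtype == object:
            return i
    return -1  # every column looks convertible

# Fake samples shaped like (img, label, label_res, label_sub, label_id):
good = (np.zeros((3, 64, 256), np.float32), "abc", "ab", "c", 2)
bad = (None, "abc", "ab", "c", 2)
```

Running check_sample over a handful of raw samples before training starts would pinpoint the offending column, if the dataset is indeed the source of the None.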

We load the data through a load_dataset function, which does three things: instantiates the lmdbDataset class, wraps it with MindSpore's GeneratorDataset, and batches the dataset:

def load_dataset():
    # instantiate the lmdbDataset class
    train_data_set = cfgs.dataset_cfgs['dataset_train'](**cfgs.dataset_cfgs['dataset_train_args'])
    # i.e. train_data_set = lmdbDataset(roots=['./datasets/train/SynthText', './datasets/train/MJSynth'],
    #                                   img_height=64, img_width=256,
    #                                   transform=dataset.transforms.Compose([vision.ToTensor()]),
    #                                   global_state='Train')
    # wrap it with GeneratorDataset
    train_loader = ds.GeneratorDataset(train_data_set, column_names=["image", "label", "label_res", "label_sub", "label_id"],
                                       num_parallel_workers=32, shuffle=True)
    # batch it
    train_loader = train_loader.batch(batch_size=384)

    test_data_set = cfgs.dataset_cfgs['dataset_test'](**cfgs.dataset_cfgs['dataset_test_args'])
    test_loader = ds.GeneratorDataset(test_data_set, column_names=["image", "label"], num_parallel_workers=16, shuffle=False)
    test_loader = test_loader.batch(batch_size=64)

    return train_data_set, train_loader, test_data_set, test_loader

We then load the training set and pass it to model.train:

    # load the training and test sets
    train_data_set, train_loader, test_data_set, test_loader = load_dataset()

    # define the multi-label loss function
    loss = VisionLAN_Loss()     # custom loss
    # loss = nn.SoftmaxCrossEntropyWithLogits()  # equivalent

    # define the loss network, connecting the forward network and the multi-label loss
    loss_net = CustomWithLossCell(net, loss)

    # define the Model; in this multi-label setup, Model needs no loss function of its own
    model = Model(network=loss_net, optimizer=optimizer)

    # train the model
    model.train(epoch=8, train_dataset=train_loader, callbacks=[LossMonitor()])
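Note that the dataset columns are fed to the loss cell positionally, in column_names order, so it is worth double-checking that the signature of the cell's construct method matches. A plain-Python sketch of the pattern (this class body is our illustration, not the original CustomWithLossCell; the real class would subclass mindspore.nn.Cell):

```python
class CustomWithLossCellSketch:
    """Sketch of the with-loss-cell pattern: run the backbone on the image
    column, then hand its output plus the label columns to the loss function."""
    def __init__(self, backbone, loss_fn):
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, image, label, label_res, label_sub, label_id):
        output = self._backbone(image)
        return self._loss_fn(output, label, label_res, label_sub, label_id)

# Toy usage with stand-in callables for the backbone and the loss:
cell = CustomWithLossCellSketch(lambda img: img * 2,
                                lambda out, *labels: out + len(labels))
cell.construct(1, "a", "b", "c", 5)
```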

We are unsure whether the problem lies in our __getitem__ implementation itself or in some call site that passes in an unexpected data type; we are especially puzzled about where the None in [None, Int64] comes from. We would greatly appreciate any help from experts and peers, and thank you very much for your time!

@panshaowu
Collaborator

panshaowu commented Jan 31, 2024

@sevennotmouse
Hello, and thanks for the feedback. Apologies for the late reply; have you perhaps already resolved the problem?

RuntimeError: The 'getitem' operation does not support the type [None, Int64]. The supported types of overload function getitem is: [Tuple, Slice], [List, Slice], [Tensor, Ellipsis], [Tuple, Tensor], [List, Number], [Tensor, Slice], [Dictionary, String], [Tensor, Tensor], [String, Number], [Tensor, Tuple], [Tensor, None], [Tuple, Number], [Tensor, Number], [Tensor, List].

Judging from the error-log excerpt you provided, the problem is most likely not introduced by the __getitem__ method of your custom dataset class lmdbDataset, but by the getitem operator receiving an unsupported input type. Our guess is that some network-construction code uses syntax the framework does not expect. You could try to:

  1. Run in MindSpore's PyNative mode and single-step through the code to locate the statement that introduces the problem:
import mindspore as ms
ms.set_context(mode=ms.PYNATIVE_MODE)
  2. Or unit-test the network-construction code to locate the statement that introduces the problem.

In addition, if you suspect a problem while developing the custom dataset class, you can unit-test the lmdbDataset class to confirm that it behaves as expected. The MindOCR project also implements several custom dataset classes (including a similar lmdbDataset class), whose code may serve as a reference.
We also recommend trying the newer MindSpore r2.2.11. Since the behavior of some APIs may change when MindSpore versions are upgraded, please consult the technical documentation on the official website.

@panshaowu panshaowu self-assigned this Feb 18, 2024