When I used svtr_tiny_ch.yaml to configure training, there was a loss of nan #678

Open
zx214 opened this issue Mar 7, 2024 · 11 comments

@zx214

zx214 commented Mar 7, 2024

When I used svtr_tiny_ch.yaml to configure training, I encountered a loss of NaN. I used the pre-trained model svtr_tiny_ch-2ee6ade4.ckpt, and I am using the GPU version of MindSpore. The log is as follows:
[WARNING] ME(3613:140537237112640,ForkServerPoolWorker-1:1):2024-03-07-06:39:44.569.244 [mindspore/train/summary/_summary_adapter.py:304] There are no valid values in the ndarray(size=40960, shape=(1, 640, 64))
[WARNING] ME(3613:140537237112640,ForkServerPoolWorker-1:1):2024-03-07-06:39:44.569.675 [mindspore/train/summary/_summary_adapter.py:304] There are no valid values in the ndarray(size=864, shape=(32, 3, 3, 3))
[WARNING] ME(3613:140537237112640,ForkServerPoolWorker-1:1):2024-03-07-06:39:44.569.892 [mindspore/train/summary/_summary_adapter.py:304] There are no valid values in the ndarray(size=32, shape=(32,))
[WARNING] ME(3613:140537237112640,ForkServerPoolWorker-1:1):2024-03-07-06:39:44.570.100 [mindspore/train/summary/_summary_adapter.py:304] There are no valid values in the ndarray(size=32, shape=(32,))
[WARNING] ME(3613:140537237112640,ForkServerPoolWorker-1:1):2024-03-07-06:39:44.570.365 [mindspore/train/summary/_summary_adapter.py:304] There are no valid values in the ndarray(size=18432, shape=(64, 32, 3, 3))
[2024-03-07 06:40:10] mindocr.utils.callbacks INFO - epoch: [1/3] step: [100/1000], loss: nan, lr: 0.000033, per step time: 879.835 ms, fps per card: 18.19 img/s
[WARNING] ME(3621:140537237112640,ForkServerPoolWorker-1:9):2024-03-07-06:40:10.246.843 [mindspore/train/summary/_summary_adapter.py:304] There are no valid values in the ndarray(size=40960, shape=(1, 640, 64))
[WARNING] ME(3621:140537237112640,ForkServerPoolWorker-1:9):2024-03-07-06:40:10.247.187 [mindspore/train/summary/_summary_adapter.py:304] There are no valid values in the ndarray(size=864, shape=(32, 3, 3, 3))
[WARNING] ME(3621:140537237112640,ForkServerPoolWorker-1:9):2024-03-07-06:40:10.247.432 [mindspore/train/summary/_summary_adapter.py:304] There are no valid values in the ndarray(size=32, shape=(32,))
[WARNING] ME(3621:140537237112640,ForkServerPoolWorker-1:9):2024-03-07-06:40:10.247.616 [mindspore/train/summary/_summary_adapter.py:304] There are no valid values in the ndarray(size=32, shape=(32,))
[WARNING] ME(3621:140537237112640,ForkServerPoolWorker-1:9):2024-03-07-06:40:10.247.852 [mindspore/train/summary/_summary_adapter.py:304] There are no valid values in the ndarray(size=18432, shape=(64, 32, 3, 3))
[2024-03-07 06:40:35] mindocr.utils.callbacks INFO - epoch: [1/3] step: [200/1000], loss: nan, lr: 0.000066, per step time: 252.681 ms, fps per card: 63.32 img/s

@horcham
Collaborator

horcham commented Mar 7, 2024

Hello, issue 670 describes a similar problem; you could refer to it.

We suggest trying the following solutions:

  1. Check that the MindSpore and CANN versions match. You can run the following code to check:
    import mindspore
    mindspore.set_context(device_target="Ascend")
    mindspore.run_check()
    exit()
    Please refer to https://www.mindspore.cn/install/en#configuring-environment-variables for more details. If errors occur during the version check, please follow the installation guide and reinstall CANN and MindSpore.
  2. The O2 amp_level might cause unstable training. You could try setting amp_level to O0. For example, if you are using configs/rec/svtr/svtr_tiny.yaml, you will find
    system:
      ...
      amp_level: O2
      amp_level_infer: O2 # running inference in O2 mode
      ...
    
    Replace O2 with O0, so that the yaml file becomes:
    system:
      ...
      amp_level: O0
      amp_level_infer: O0
      ...
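Since the original report mentions the GPU build of MindSpore, the check in step 1 can be adapted as in the minimal sketch below. The helper name `check_mindspore` is hypothetical; `set_context` and `run_check` are the same calls used above, with `device_target` set to "GPU" instead of "Ascend". On GPU, CANN is not involved, so the CUDA/cuDNN versions are what need to match the MindSpore build.

```python
def check_mindspore(device_target="GPU"):
    """Run MindSpore's built-in self-check on the given device (hypothetical helper)."""
    try:
        import mindspore
    except ImportError:
        return "mindspore is not installed in this environment"
    try:
        # Same calls as in step 1, only the device target differs.
        mindspore.set_context(device_target=device_target)
        mindspore.run_check()  # prints its own verdict to stdout
    except Exception as exc:  # e.g. the requested device is not available
        return f"self-check could not run on {device_target}: {exc}"
    return f"run_check finished on {device_target} with MindSpore {mindspore.__version__}; see its output above"


print(check_mindspore("GPU"))
```

The function only reports whether the check could be started; the actual pass/fail message is the one printed by `run_check` itself.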
    

Would you also mind providing the following information, so that we can locate the problem faster:

  • your GPU information
  • cuda/cudnn version
  • mindspore version
  • dataset you use
  • the yaml config, if you have modified it

@zx214
Author

zx214 commented Mar 7, 2024

My yaml config is as follows:
system:
  mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
  distribute: False
  amp_level: O0
  amp_level_infer: O0 # running inference in O2 mode
  seed: 42
  log_interval: 100
  val_while_train: True
  drop_overflow_update: False
  ckpt_max_keep: 5

common:
  character_dict_path: &character_dict_path mindocr/utils/dict/ch_dict.txt
  num_classes: &num_classes 6624 # num_chars_in_dict + 1
  max_text_len: &max_text_len 40
  use_space_char: &use_space_char True
  batch_size: &batch_size 256
  num_workers: &num_workers 1
  num_epochs: &num_epochs 1
  dataset_root: &dataset_root ./dataset_ic15/rec
  ckpt_save_dir: &ckpt_save_dir ./tmp_rec
  ckpt_load_path: &ckpt_load_path ./tmp_rec/best.ckpt
  resume: &resume False
  data_shape: (32, 320)
  resume_epochs: 0
  lr: 0.01

model:
  type: rec
  transform: null
  backbone:
    name: SVTRNet
    pretrained: False
    img_size: [32, 320]
    out_channels: 96
    patch_merging: Conv
    embed_dim: [64, 128, 256]
    depth: [3, 6, 3]
    num_heads: [2, 4, 8]
    mixer:
      [
        "Local",
        "Local",
        "Local",
        "Local",
        "Local",
        "Local",
        "Global",
        "Global",
        "Global",
        "Global",
        "Global",
        "Global",
      ]
    local_mixer: [[7, 11], [7, 11], [7, 11]]
    last_stage: True
    prenorm: False
  neck:
    name: Img2Seq
  head:
    name: CTCHead
    out_channels: *num_classes
  pretrained: False

postprocess:
  name: RecCTCLabelDecode
  character_dict_path: *character_dict_path
  use_space_char: *use_space_char

metric:
  name: RecMetric
  main_indicator: acc
  lower: False
  character_dict_path: *character_dict_path
  ignore_space: True
  print_flag: False

loss:
  name: CTCLoss
  pred_seq_len: 80 # 320 / 4
  max_label_len: *max_text_len # this value should be smaller than pre_seq_len
  batch_size: *batch_size

scheduler:
  scheduler: warmup_cosine_decay
  min_lr: 0.00001
  lr: 0.001
  num_epochs: *num_epochs
  warmup_epochs: 3
  decay_epochs: 27
optimizer:
  opt: adamw
  grouping_strategy: svtr
  filter_bias_and_bn: False
  weight_decay: 0.05

loss_scaler:
  type: dynamic
  loss_scale: 512
  scale_factor: 2.0
  scale_window: 1000

train:
  ckpt_save_dir: *ckpt_save_dir
  dataset_sink_mode: True
  ema: True
  ema_decay: 0.9999
  dataset:
    type: RecDataset
    dataset_root: *dataset_root
    data_dir: Task3/train
    label_file: train_rec_gt.txt
    sample_ratio: 1.0
    shuffle: True
    filter_max_len: True
    max_text_len: *max_text_len
    transform_pipeline:
      - DecodeImage:
          img_mode: BGR
          to_float32: False
      - RecCTCLabelEncode:
          max_text_len: *max_text_len
          character_dict_path: *character_dict_path
          use_space_char: *use_space_char
          lower: False
      - Rotate90IfVertical:
          threshold: 2.0
          direction: counterclockwise
      - SVTRRecResizeImg:
          image_shape: [32, 320]
          padding: True
      - NormalizeImage:
          bgr_to_rgb: True
          is_hwc: True
          mean: [127.0, 127.0, 127.0]
          std: [127.0, 127.0, 127.0]
      - ToCHWImage:
    output_columns: ["image", "text_seq"]
    net_input_column_index: [0]
    label_column_index: [1]

  loader:
    shuffle: True
    batch_size: *batch_size
    drop_remainder: True
    max_rowsize: 12
    num_workers: 4

eval:
  ckpt_load_path: *ckpt_load_path
  dataset_sink_mode: False
  dataset:
    type: RecDataset
    dataset_root: *dataset_root
    data_dir: Task3/test
    label_file: test_rec_gt.txt
    sample_ratio: 1.0
    shuffle: False
    transform_pipeline:
      - DecodeImage:
          img_mode: BGR
          to_float32: False
      - RecCTCLabelEncode:
          max_text_len: *max_text_len
          character_dict_path: *character_dict_path
          use_space_char: *use_space_char
          lower: False
      - Rotate90IfVertical:
          threshold: 2.0
          direction: counterclockwise
      - SVTRRecResizeImg:
          image_shape: [32, 320] # H, W
          padding: True
      - NormalizeImage:
          bgr_to_rgb: True
          is_hwc: True
          mean: [127.0, 127.0, 127.0]
          std: [127.0, 127.0, 127.0]
      - ToCHWImage:
    output_columns: ["image", "text_padded", "text_length"]
    net_input_column_index: [0]
    label_column_index: [1, 2]

  loader:
    shuffle: False
    batch_size: 64
    drop_remainder: False
    max_rowsize: 12
    num_workers: 1

@horcham
Collaborator

horcham commented Mar 8, 2024

Training with O0 might work. You could set amp_level and amp_level_infer to O0 and try training the model again.

The O2 amp_level might cause unstable training. You could try setting amp_level to O0. For example, if you are using configs/rec/svtr/svtr_tiny.yaml, you will find

system:
  ...
  amp_level: O2
  amp_level_infer: O2 # running inference in O2 mode
  ...

Replace O2 with O0, so that the yaml file becomes:

system:
  ...
  amp_level: O0
  amp_level_infer: O0
  ...

@zx214
Author

zx214 commented Mar 8, 2024

In my config, I have already set amp_level and amp_level_infer to O0. How can I solve this problem?

@horcham
Collaborator

horcham commented Mar 8, 2024

Can it train normally with svtr_tiny.yaml instead of svtr_tiny_ch.yaml?

@zx214
Author

zx214 commented Mar 8, 2024

There is no problem training with English datasets and svtr_tiny.yaml.

@horcham
Collaborator

horcham commented Mar 8, 2024

It seems you are training on the ic15 dataset with svtr_tiny_ch, but ic15 does not contain Chinese characters. How about using svtr_tiny.yaml instead?

@zx214
Author

zx214 commented Mar 8, 2024

I used the RCTW dataset on svtr_tiny_ch with the svtr_tiny_ch-2ee6ade4.ckpt pretrained model.

@horcham
Collaborator

horcham commented Mar 11, 2024

Would you mind giving us more information, so that we can locate the problem faster?

  • your GPU information
  • cuda/cudnn version
  • mindspore version

@zx214
Author

zx214 commented Mar 12, 2024

CUDA version: 11.4, MindSpore version: 2.2.11

@horcham
Collaborator

horcham commented Apr 12, 2024

The log shows repeated "There are no valid values in the ndarray" warnings, which suggests that some data may produce very large or overflowing gradients during training and cause the NaN loss. To avoid this, change "drop_overflow_update: False" to "drop_overflow_update: True".
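To illustrate why this setting can matter, here is a simplified, framework-free sketch of dynamic loss scaling with and without dropping overflowed updates. It is not MindOCR's or MindSpore's implementation: the class `DynamicLossScaler` and the function `train` are hypothetical, the gradients are plain Python floats, and the scaling of the loss itself is left out. The defaults 512, 2.0 and 1000 are taken from the `loss_scaler` section of the config posted above.

```python
import math


class DynamicLossScaler:
    """Toy model of a dynamic loss scaler (hypothetical, for illustration)."""

    def __init__(self, loss_scale=512.0, scale_factor=2.0, scale_window=1000):
        self.loss_scale = loss_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0

    def step(self, grads, drop_overflow_update):
        """Update the scale and return True if the optimizer step should run."""
        overflow = any(not math.isfinite(g) for g in grads)
        if overflow:
            # Shrink the scale and restart the window of clean steps.
            self.loss_scale = max(self.loss_scale / self.scale_factor, 1.0)
            self.good_steps = 0
            return not drop_overflow_update
        self.good_steps += 1
        if self.good_steps >= self.scale_window:
            # Enough clean steps in a row: try a larger scale again.
            self.loss_scale *= self.scale_factor
            self.good_steps = 0
        return True


def train(grad_stream, drop_overflow_update, lr=0.1):
    """Apply plain SGD to a single weight, optionally skipping overflowed steps."""
    scaler = DynamicLossScaler()
    weight = 1.0
    for grad in grad_stream:
        if scaler.step([grad], drop_overflow_update):
            weight -= lr * grad
    return weight, scaler.loss_scale


grads = [0.5, float("nan"), 0.5]
print(train(grads, drop_overflow_update=False))  # the NaN gradient is applied
print(train(grads, drop_overflow_update=True))   # the NaN gradient is skipped
```

In this toy model a single non-finite gradient is enough to turn the weight into NaN permanently when the update is not dropped, which matches the pattern in the log where every reported step shows `loss: nan`.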
