본문 바로가기

Development/OCR

EasyOCR 사용자 모델 학습하기 (6) - 한글 학습데이터 생성, 학습 및 테스트

728x90
반응형

출처: https://www.jaided.ai/

 

이번에는 한글 학습데이터를 생성하고 학습, 검증하는 과정을 진행해 보고자 한다. 이전에 작성한 학습데이터 생성부터 변환, 미세조정(Fine-tune) 학습 등에 대한 내용은 아래 포스트를 참고하기 바란다.

 

 

[Development/OCR] - EasyOCR 사용자 모델 학습하기 (1) - 시작하기 전에

 

[Development/OCR] - EasyOCR 사용자 모델 학습하기 (2) - 학습데이터 생성

 

[Development/OCR] - EasyOCR 사용자 모델 학습하기 (3) - 학습데이터 변환

 

[Development/OCR] - EasyOCR 사용자 모델 학습하기 (4) - 모델 학습

 

[Development/OCR] - EasyOCR 사용자 모델 학습하기 (5) - 모델 적용 및 테스트

 


0. 시작하기 전에

 

0.1 학습 설정

학습데이터 생성, 변환 및 학습에 필요한 설정을 다음과 같이 정의했다.

 

  • 학습클래스: EasyOCR 프로젝트 './easyocr/config.py' 파일의 'korean_g2' > 'characters' 사용 
  • 학습데이터셋: Training 10,000 / Validation 1,000 / Test 1,000
  • 미세조정(Fine-tune) 학습을 위한 Pre-trained 모델: EasyOCR 프로젝트의 'korean_g2.pth'
  • 학습모델의 모듈 조합: EasyOCR 프로젝트의 'korean_g2.pth'와 같은 'None-VGG-BiLSTM-CTC'

 

0.2 작업 디렉토리 구조

각 단계에서 필요로 하는 입/출력 데이터가와 설정 파일들이 저장될 작업 디렉토리의 구조는 다음과 같이 정의했다.

 

/workspace
├── /step1
│   ├── /training
│   │   └── /kordata
│   │       # TRDG 프로젝트를 통해 생성된 한글 학습데이터
│   │       # [gt]_[idx].[ext]
│   │       ├── 가_00001.jpg
│   │       ├── 나_00002.jpg
│   │       ├── 다_00003.jpg
│   │       └── ...
│   ├── /validation
│   └── /test
│
├── /step2
│   ├── /training
│   │   └── /kordata
│   │       # TRDG2DTRB 프로젝트를 통해 변환된 한글 학습데이터
│   │       ├── gt.txt
│   │       └── /images
│   │           #    image_[idx].[ext]
│   │           ├── image_00001.png
│   │           ├── image_00001.png
│   │           ├── image_00001.png
│   │           └── ...
│   ├── /validation
│   └── /test
│
├── /step3
│   ├── /training
│   │   └── /kordata
│   │       # DTRB 프로젝트의 'create_lmdb_dataset.py'를 통해 lmdb 형태로 변환된 한글 학습데이터
│   │       ├── data.lmdb
│   │       └── data.lmdb
│   ├── /validation
│   └── /test
│
├── /pre_trained_model
│   # DTRB 프로젝트에서 사용하게 될 Pre-trained 모델파일의 경로
│   └── korean_g2.pth
│
├── /user_network_dir
│   # EasyOCR 프로젝트에서 사용하게 될 사용자 모델 및 설정 파일의 저장 경로
│   ├── custom.pth
│   ├── custom.py
│   └── custom.yaml
│
└── /demo_images
    # 학습한 모델의 성능을 확인하기 위한 테스트 이미지
    ├── demo_01.png
    ├── demo_02.png
    └── ...

 

다만, deep-text-recognition-benchmark 프로젝트의 학습 결과가 저장되는 경로를 변경하려면 소스코드의 여러 부분을 수정해야 하기에, 이 부분은 기본값으로 주어진 './saved_models/'에 저장되도록 두었으니 참고하기 바란다.

 

 

1. 학습데이터 생성

TextRecognitionDataGenerator 프로젝트를 이용해 한글 학습데이터를 생성한다.

자세한 학습데이터 생성 과정은 아래 포스트에서 확인하기 바란다.

 

[Development/OCR] - EasyOCR 사용자 모델 학습하기 (1) - 시작하기 전에

 

EasyOCR 사용자 모델 학습하기 (1) - 시작하기 전에

이번에는 EasyOCR에서 제공하는 API를 통해 OCR 기능을 이용할 때 사용되는 기본 신경망 모델이 아닌, 사용자가 직접 학습시키고자 하는 데이터를 준비해 학습하고, 원하는 성능의 모델을 만들어 사

davelogs.tistory.com

 

다음의 명령구문으로 각각 training / validation / test 데이터셋을 생성한다.

 

# TextRecognitionDataGenerator 프로젝트를 이용해 한글 학습데이터 생성하기
# TextRecognitionDataGenerator 프로젝트 root에서 실행

# training: 10,000개
(venv) $ python3 ./trdg/run.py \
        --output_dir "./workspace/step1/training/kordata" \
        --language "ko" \
        --count 10000
    
# validation: 1,000개
(venv) $ python3 ./trdg/run.py \
        --output_dir "./workspace/step1/validation/kordata" \
        --language "ko" \
        --count 1000

# test: 1,000개
(venv) $ python3 ./trdg/run.py \
        --output_dir "./workspace/step1/test/kordata" \
        --language "ko" \
        --count 1000

 

 

2. 학습데이터 변환

TRDG2DTRB 프로젝트를 이용해 학습데이터를 변환한다.

자세한 학습데이터 변환 과정은 아래 포스트에서 확인하기 바란다.

 

[Development/OCR] - EasyOCR 사용자 모델 학습하기 (3) - 학습데이터 변환

 

EasyOCR 사용자 모델 학습하기 (3) - 학습데이터 변환

이전 포스트에서 TextRecognitionDataGenerator 프로젝트를 이용해 생성한 학습데이터는 학습을 위한 deep-text-recognition-benchmark 프로젝트에서 요구하는 데이터 구조는 아니었고, 또한 바로 사용할 수 없었

davelogs.tistory.com

 

다음의 명령 구문으로 각각 training / validation / test 데이터셋을 변환한다.

 

# TRDG2DTRB 프로젝트를 이용해 한글 학습데이터 변환하기
# TRDG2DTRB 프로젝트 root에서 실행

# train 학습데이터 변환
(venv) $ python3 convert.py \
        --input_path "./workspace/step1/training/kordata" \
        --output_path "./workspace/step2/training"

# validation 학습데이터 변환
(venv) $ python3 convert.py \
        --input_path "./workspace/step1/validation/kordata" \
        --output_path "./workspace/step2/validation/kordata"

# test 학습데이터 변환
(venv) $ python3 convert.py \
        --input_path "./workspace/step1/test/kordata" \
        --output_path "./workspace/step2/test/kordata"

 

 

3. 모델 학습

deep-text-recognition-benchmark 프로젝트를 이용해 사용자 모델을 학습한다.

자세한 모델 학습 과정은 아래 포스트에서 확인하기 바란다.

 

[Development/OCR] - EasyOCR 사용자 모델 학습하기 (4) - 모델 학습

 

EasyOCR 사용자 모델 학습하기 (4) - 모델 학습

신경망 모델 학습하기에 앞서 학습에 필요한 학습데이터 생성 및 변환 등에 대한 내용은 이전 포스트를 참고하기 바란다. [Development/OCR] - EasyOCR 사용자 모델 학습하기 (1) - 시작하기 전에 [Developme

davelogs.tistory.com

 

3.1 학습용 데이터 포맷 lmdb 형태로 변환

다음 구문으로 실제 학습에서 사용할 lmdb 포맷으로 학습데이터를 변환한다.

 

# deep-text-recognition-benchmark 프로젝트에서 사용할 학습데이터 포맷으로의 변환
# deep-text-recognition-benchmark 프로젝트 root에서 실행

# training 데이터 변환
(venv) $ create_lmdb_dataset.py \
        --gtFile "./workspace/step2/training/kordata/gt.txt" \
        --inputPath "./workspace/step2/training/kordata" \
        --outputPath "./workspace/step3/training/kordata"

# validation 데이터 변환
(venv) $ create_lmdb_dataset.py \
        --gtFile "./workspace/step2/validation/kordata/gt.txt" \
        --inputPath "./workspace/step2/validation/kordata" \
        --outputPath "./workspace/step3/validation/kordata"

# test 데이터 변환
(venv) $ create_lmdb_dataset.py \
        --gtFile "./workspace/step2/test/kordata/gt.txt" \
        --inputPath "./workspace/step2/test/kordata" \
        --outputPath "./workspace/step3/test/kordata"

 

 

3.2 한글 클래스 적용

deep-text-recognition-benchmark 프로젝트의 'train.py' 파일을 열고 다음 구문을 삽입한다.

 

# deep-text-recognition-benchmark 프로젝트 'train.py'의 285라인에 아래 코드를 삽입한다.
# 아래 character는 EasyOCR 프로젝트 './easyocr/config.py'의 'korean_g2' > 'character'에 해당한다.

opt.character = " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~가각간갇갈감갑값강갖같갚갛개객걀거걱건걷걸검겁것겉게겨격겪견결겹경곁계고곡곤곧골곰곱곳공과관광괜괴굉교구국군굳굴굵굶굽궁권귀규균그극근글긁금급긋긍기긴길김깅깊까깎깐깔깜깝깥깨꺼꺾껍껏껑께껴꼬꼭꼴꼼꼽꽂꽃꽉꽤꾸꿀꿈뀌끄끈끊끌끓끔끗끝끼낌나낙낚난날낡남납낫낭낮낯낱낳내냄냉냐냥너넉널넓넘넣네넥넷녀녁년념녕노녹논놀놈농높놓놔뇌뇨누눈눕뉘뉴늄느늑는늘늙능늦늬니닐님다닥닦단닫달닭닮담답닷당닿대댁댐더덕던덜덤덥덧덩덮데델도독돈돌돕동돼되된두둑둘둠둡둥뒤뒷드득든듣들듬듭듯등디딩딪따딱딴딸땀땅때땜떠떡떤떨떻떼또똑뚜뚫뚱뛰뜨뜩뜯뜰뜻띄라락란람랍랑랗래랜램랫략량러럭런럴럼럽럿렁렇레렉렌려력련렬렵령례로록론롬롭롯료루룩룹룻뤄류륙률륭르른름릇릎리릭린림립릿마막만많말맑맘맙맛망맞맡맣매맥맨맵맺머먹먼멀멈멋멍멎메멘멩며면멸명몇모목몰몸몹못몽묘무묵묶문묻물뭄뭇뭐뭣므미민믿밀밉밌및밑바박밖반받발밝밟밤밥방밭배백뱀뱃뱉버번벌범법벗베벤벼벽변별볍병볕보복볶본볼봄봇봉뵈뵙부북분불붉붐붓붕붙뷰브블비빌빗빚빛빠빨빵빼뺨뻐뻔뻗뼈뽑뿌뿐쁘쁨사삭산살삶삼상새색샌생서석섞선설섬섭섯성세센셈셋션소속손솔솜솟송솥쇄쇠쇼수숙순술숨숫숲쉬쉽슈스슨슬슴습슷승시식신싣실싫심십싱싶싸싹쌀쌍쌓써썩썰썹쎄쏘쏟쑤쓰쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗앙앞애액야약얇양얗얘어억언얹얻얼엄업없엇엉엌엎에엔엘여역연열엷염엽엿영옆예옛오옥온올옮옳옷와완왕왜왠외왼요욕용우욱운울움웃웅워원월웨웬위윗유육율으윽은을음응의이익인일읽잃임입잇있잊잎자작잔잖잘잠잡장잦재쟁저적전절젊점접젓정젖제젠젯져조족존졸좀좁종좋좌죄주죽준줄줌줍중쥐즈즉즌즐즘증지직진질짐집짓징짙짚짜짝짧째쨌쩌쩍쩐쪽쫓쭈쭉찌찍찢차착찬찮찰참창찾채책챔챙처척천철첫청체쳐초촉촌총촬최추축춘출춤춥춧충취츠측츰층치칙친칠침칭카칸칼캐캠커컨컬컴컵컷켓켜코콜콤콩쾌쿠퀴크큰클큼키킬타탁탄탈탑탓탕태택탤터턱털텅테텍텔템토톤톱통퇴투툼퉁튀튜트특튼튿틀틈티틱팀팅파팎판팔패팩팬퍼퍽페펴편펼평폐포폭표푸푹풀품풍퓨프플픔피픽필핏핑하학한할함합항해핵핸햄햇행향허헌험헤헬혀현혈협형혜호혹혼홀홍화확환활황회획횟효후훈훌훔훨휘휴흉흐흑흔흘흙흡흥흩희흰히힘"

 

위와 같이 코드를 수정하는 걸 원치 않는다면, 학습 구문을 실행할 때 '--character' 파라미터에 전달해도 무방하다.

 

3.3 EasyOCR의 한글 모델 다운로드

deep-text-recognition-benchmark 프로젝트에서는 한글을 학습한 Pre-trained 모델을 제공하지 않기 때문에, EasyOCR 프로젝트에서 사용 중인 모델을 미세조정(Fine-tune) 학습을 위한 Pre-trained 모델로 사용하고자 한다. 해당 모델은 아래 EasyOCR Model Hub에서 다운로드할 수 있다.

 

https://www.jaided.ai/easyocr/modelhub/

 

Jaided AI: EasyOCR model hub

Lost your password? Please enter your email address. You will receive a link to create a new password. Back to log-in

www.jaided.ai

 

EasyOCR 프로젝트의 Model Hub에서 'korean_g2'를 다운로드한다.

 

다운로드한 파일은 './workspace/pre_trained_model/' 에 저장한다.

 

3.4 모델 학습

이제 학습을 위한 준비는 모두 마쳤다. 다음 구문을 통해 학습을 진행하자.

단, 이전과 다르게 '--input_channel', '--output_channel', '--hidden_size' 옵션을 추가로 설정한 것은 EasyOCR 프로젝트의 'korean_g2' 모델을 Pre-trained 모델로 사용하기 위함이다.

 

# deep-text-recognition-benchmark 프로젝트를 이용한 모델 학습
# deep-text-recognition-benchmark 프로젝트 root에서 실행
(venv) $ python3 train.py \
        --train_data "./workspace/step3/training" \
        --valid_data "./workspace/step3/validation" \
        --select_data / \
        --batch_ratio 1 \
        --Transformation None \
        --FeatureExtraction "VGG" \
        --SequenceModeling "BiLSTM" \
        --Prediction "CTC" \
        --input_channel 1 \
        --output_channel 256 \
        --hidden_size 256 \
        --saved_model "./workspace/pre_trained_model/korean_g2.pth" \
        --FT

# character 적용을 위한 소스코드 수정을 피하고 싶은 경우
(venv) $ python3 train.py \
        --train_data "./workspace/step3/training" \
        --valid_data "./workspace/step3/validation" \
        --select_data / \
        --batch_ratio 1 \
        --Transformation None \
        --FeatureExtraction "VGG" \
        --SequenceModeling "BiLSTM" \
        --Prediction "CTC" \
        --input_channel 1 \
        --output_channel 256 \
        --hidden_size 256 \
        --character " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~가각간갇갈감갑값강갖같갚갛개객걀거걱건걷걸검겁것겉게겨격겪견결겹경곁계고곡곤곧골곰곱곳공과관광괜괴굉교구국군굳굴굵굶굽궁권귀규균그극근글긁금급긋긍기긴길김깅깊까깎깐깔깜깝깥깨꺼꺾껍껏껑께껴꼬꼭꼴꼼꼽꽂꽃꽉꽤꾸꿀꿈뀌끄끈끊끌끓끔끗끝끼낌나낙낚난날낡남납낫낭낮낯낱낳내냄냉냐냥너넉널넓넘넣네넥넷녀녁년념녕노녹논놀놈농높놓놔뇌뇨누눈눕뉘뉴늄느늑는늘늙능늦늬니닐님다닥닦단닫달닭닮담답닷당닿대댁댐더덕던덜덤덥덧덩덮데델도독돈돌돕동돼되된두둑둘둠둡둥뒤뒷드득든듣들듬듭듯등디딩딪따딱딴딸땀땅때땜떠떡떤떨떻떼또똑뚜뚫뚱뛰뜨뜩뜯뜰뜻띄라락란람랍랑랗래랜램랫략량러럭런럴럼럽럿렁렇레렉렌려력련렬렵령례로록론롬롭롯료루룩룹룻뤄류륙률륭르른름릇릎리릭린림립릿마막만많말맑맘맙맛망맞맡맣매맥맨맵맺머먹먼멀멈멋멍멎메멘멩며면멸명몇모목몰몸몹못몽묘무묵묶문묻물뭄뭇뭐뭣므미민믿밀밉밌및밑바박밖반받발밝밟밤밥방밭배백뱀뱃뱉버번벌범법벗베벤벼벽변별볍병볕보복볶본볼봄봇봉뵈뵙부북분불붉붐붓붕붙뷰브블비빌빗빚빛빠빨빵빼뺨뻐뻔뻗뼈뽑뿌뿐쁘쁨사삭산살삶삼상새색샌생서석섞선설섬섭섯성세센셈셋션소속손솔솜솟송솥쇄쇠쇼수숙순술숨숫숲쉬쉽슈스슨슬슴습슷승시식신싣실싫심십싱싶싸싹쌀쌍쌓써썩썰썹쎄쏘쏟쑤쓰쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗앙앞애액야약얇양얗얘어억언얹얻얼엄업없엇엉엌엎에엔엘여역연열엷염엽엿영옆예옛오옥온올옮옳옷와완왕왜왠외왼요욕용우욱운울움웃웅워원월웨웬위윗유육율으윽은을음응의이익인일읽잃임입잇있잊잎자작잔잖잘잠잡장잦재쟁저적전절젊점접젓정젖제젠젯져조족존졸좀좁종좋좌죄주죽준줄줌줍중쥐즈즉즌즐즘증지직진질짐집짓징짙짚짜짝짧째쨌쩌쩍쩐쪽쫓쭈쭉찌찍찢차착찬찮찰참창찾채책챔챙처척천철첫청체쳐초촉촌총촬최추축춘출춤춥춧충취츠측츰층치칙친칠침칭카칸칼캐캠커컨컬컴컵컷켓켜코콜콤콩쾌쿠퀴크큰클큼키킬타탁탄탈탑탓탕태택탤터턱털텅테텍텔템토톤톱통퇴투툼퉁튀튜트특튼튿틀틈티틱팀팅파팎판팔패팩팬퍼퍽페펴편펼평폐포폭표푸푹풀품풍퓨프플픔피픽필핏핑하학한할함합항해핵핸햄햇행향허헌험헤헬혀현혈협형혜호혹혼홀홍화확환활황회획횟효후훈훌훔훨휘휴흉흐흑흔흘흙흡흥흩희흰히힘"
        --saved_model "./workspace/pre_trained_model/korean_g2.pth" \
        --FT

 

3.5 모델 학습 결과 확인

deep-text-recognition-benchmark 프로젝트를 통해 모델의 학습이 진행되는 동안 터미널을 통해 확인할 수도 있지만, './saved_models/None-VGG-BiLSTM-CTC-Seed1111' 디렉토리에서 학습로그를 확인할 수도 있다. 해당 디렉토리에서 확인할 수 있는 학습로그는 다음과 같다.

 

모델 학습로그 'opt.txt', 'log_train.txt', 'log_dataset.txt'

 

 

728x90

 

4. 모델 적용 및 테스트

EasyOCR 프로젝트를 이용해 위에서 학습한 사용자 모델을 적용, 테스트한다.

자세한 사용자 학습 모델 적용 과정은 아래 포스트에서 확인하기 바란다.

 

[Development/OCR] - EasyOCR 사용자 모델 학습하기 (5) - 모델 적용 및 테스트

 

EasyOCR 사용자 모델 학습하기 (5) - 모델 적용 및 테스트

마지막으로, EasyOCR에서 사용자 모델을 사용할 수 있는 방법을 기술하기에 앞서, 학습에 필요한 학습데이터 생성, 변환 및 미세조정(Fine-tune) 학습 등에 대한 내용은 이전 포스트를 참고하기 바란

davelogs.tistory.com

 

 

4.1 사용자 학습 모델 복사

위에서 정의한 것처럼, 학습한 모델을 './workspace/user_network_dir'로 복사한다. 여기서는 임의로 'best_accuracy.pth' 파일을 복사해 'custom.pth'로 파일명을 변경했다.

 

4.2 'custom.yaml' 파일 생성

다음의 구문을 그대로 사용해서 파일을 생성하거나, 아래 첨부파일을 다운로드해서 사용하면 된다.

 

# korean_g2 (None-VGG-BiLSTM-CTC)
imgH: 32
lang_list: ['ko']
network_params:
  input_channel: 1
  output_channel: 256
  hidden_size: 256
character_list: " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~가각간갇갈감갑값강갖같갚갛개객걀거걱건걷걸검겁것겉게겨격겪견결겹경곁계고곡곤곧골곰곱곳공과관광괜괴굉교구국군굳굴굵굶굽궁권귀규균그극근글긁금급긋긍기긴길김깅깊까깎깐깔깜깝깥깨꺼꺾껍껏껑께껴꼬꼭꼴꼼꼽꽂꽃꽉꽤꾸꿀꿈뀌끄끈끊끌끓끔끗끝끼낌나낙낚난날낡남납낫낭낮낯낱낳내냄냉냐냥너넉널넓넘넣네넥넷녀녁년념녕노녹논놀놈농높놓놔뇌뇨누눈눕뉘뉴늄느늑는늘늙능늦늬니닐님다닥닦단닫달닭닮담답닷당닿대댁댐더덕던덜덤덥덧덩덮데델도독돈돌돕동돼되된두둑둘둠둡둥뒤뒷드득든듣들듬듭듯등디딩딪따딱딴딸땀땅때땜떠떡떤떨떻떼또똑뚜뚫뚱뛰뜨뜩뜯뜰뜻띄라락란람랍랑랗래랜램랫략량러럭런럴럼럽럿렁렇레렉렌려력련렬렵령례로록론롬롭롯료루룩룹룻뤄류륙률륭르른름릇릎리릭린림립릿마막만많말맑맘맙맛망맞맡맣매맥맨맵맺머먹먼멀멈멋멍멎메멘멩며면멸명몇모목몰몸몹못몽묘무묵묶문묻물뭄뭇뭐뭣므미민믿밀밉밌및밑바박밖반받발밝밟밤밥방밭배백뱀뱃뱉버번벌범법벗베벤벼벽변별볍병볕보복볶본볼봄봇봉뵈뵙부북분불붉붐붓붕붙뷰브블비빌빗빚빛빠빨빵빼뺨뻐뻔뻗뼈뽑뿌뿐쁘쁨사삭산살삶삼상새색샌생서석섞선설섬섭섯성세센셈셋션소속손솔솜솟송솥쇄쇠쇼수숙순술숨숫숲쉬쉽슈스슨슬슴습슷승시식신싣실싫심십싱싶싸싹쌀쌍쌓써썩썰썹쎄쏘쏟쑤쓰쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗앙앞애액야약얇양얗얘어억언얹얻얼엄업없엇엉엌엎에엔엘여역연열엷염엽엿영옆예옛오옥온올옮옳옷와완왕왜왠외왼요욕용우욱운울움웃웅워원월웨웬위윗유육율으윽은을음응의이익인일읽잃임입잇있잊잎자작잔잖잘잠잡장잦재쟁저적전절젊점접젓정젖제젠젯져조족존졸좀좁종좋좌죄주죽준줄줌줍중쥐즈즉즌즐즘증지직진질짐집짓징짙짚짜짝짧째쨌쩌쩍쩐쪽쫓쭈쭉찌찍찢차착찬찮찰참창찾채책챔챙처척천철첫청체쳐초촉촌총촬최추축춘출춤춥춧충취츠측츰층치칙친칠침칭카칸칼캐캠커컨컬컴컵컷켓켜코콜콤콩쾌쿠퀴크큰클큼키킬타탁탄탈탑탓탕태택탤터턱털텅테텍텔템토톤톱통퇴투툼퉁튀튜트특튼튿틀틈티틱팀팅파팎판팔패팩팬퍼퍽페펴편펼평폐포폭표푸푹풀품풍퓨프플픔피픽필핏핑하학한할함합항해핵핸햄햇행향허헌험헤헬혀현혈협형혜호혹혼홀홍화확환활황회획횟효후훈훌훔훨휘휴흉흐흑흔흘흙흡흥흩희흰히힘"

 

custom.yaml
0.00MB

 

4.3. 'custom.py' 모듈 파일 생성

역시 이전 포스트에서 했던 것처럼, 다음의 구문을 그대로 사용해 파일로 저장하거나, 아래 첨부파일을 다운로드해서 사용하면 된다.

 

import torch.nn as nn


class Model(nn.Module):

    def __init__(self, input_channel, output_channel, hidden_size, num_class):
        super(Model, self).__init__()
        """ FeatureExtraction """
        self.FeatureExtraction = VGG_FeatureExtractor(input_channel, output_channel)
        self.FeatureExtraction_output = output_channel
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))

        """ Sequence modeling"""
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
        self.SequenceModeling_output = hidden_size

        """ Prediction """
        self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)

    def forward(self, input, text):
        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))
        visual_feature = visual_feature.squeeze(3)

        """ Sequence modeling stage """
        contextual_feature = self.SequenceModeling(visual_feature)

        """ Prediction stage """
        prediction = self.Prediction(contextual_feature.contiguous())

        return prediction


class BidirectionalLSTM(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def forward(self, input):
        """
        input : visual feature [batch_size x T x input_size]
        output : contextual feature [batch_size x T x output_size]
        """
        try: # multi gpu needs this
            self.rnn.flatten_parameters()
        except: # quantization doesn't work with this
            pass
        recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
        output = self.linear(recurrent)  # batch_size x T x output_size
        return output


class VGG_FeatureExtractor(nn.Module):

    def __init__(self, input_channel, output_channel=256):
        super(VGG_FeatureExtractor, self).__init__()
        self.output_channel = [int(output_channel / 8), int(output_channel / 4),
                               int(output_channel / 2), output_channel]
        self.ConvNet = nn.Sequential(
            nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
            nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),
            nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False),
            nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
            nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False),
            nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True),
            nn.MaxPool2d((2, 1), (2, 1)),
            nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True))

    def forward(self, input):
        return self.ConvNet(input)

 

custom.py
0.00MB

 

 

4.4 실행 테스트

역시 이전 포스트에서 작성했던 것처럼, 아래 코드를 그대로 사용하거나, 필요한 경우 아래 첨부파일을 다운로드해서 사용하면 된다. 단, 아래 코드는 EasyOCR 프로젝트 root에서 시작되는 코드이니 주의를 바란다.

 

from easyocr.easyocr import *

# GPU 설정
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'


def get_files(path):
    file_list = []

    files = [f for f in os.listdir(path) if not f.startswith('.')]  # skip hidden file
    files.sort()
    abspath = os.path.abspath(path)
    for file in files:
        file_path = os.path.join(abspath, file)
        file_list.append(file_path)

    return file_list, len(file_list)


if __name__ == '__main__':

    # # Using default model
    # reader = Reader(['ko'], gpu=True)

    # Using custom model
    reader = Reader(['ko'], gpu=True,
                    model_storage_directory='./workspace/user_network_dir',
                    user_network_directory='./workspace/user_network_dir',
                    recog_network='custom')

    files, count = get_files('./workspace/demo_images')

    for idx, file in enumerate(files):
        filename = os.path.basename(file)

        result = reader.readtext(file)

        # ./easyocr/utils.py 733 lines
        # result[0]: bbox
        # result[1]: string
        # result[2]: confidence
        for (bbox, string, confidence) in result:
            print("filename: '%s', confidence: %.4f, string: '%s'" % (filename, confidence, string))
            # print('bbox: ', bbox)

 

run.py
0.00MB

 

실행하는 방법은 위 코드를 PyCharm과 같은 통합IDE 환경에서 실행해도 되고, 아래 명령 구문처럼 콘솔에서 직접 실행해도 된다.

 

# EasyOCR 사용자 모델 사용해 실행하기
(venv) $ python3 run.py

 

실행결과는 다음과 같다.

 

EasyOCR의 기본 모델과 사용자 모델의 실행 결과 비교

 

5. 결론

위 결과만 보면 그래도 불만족스럽긴 하지만, 문자(character) 단위 학습데이터와 단일 폰트만을 이용해 학습했기에 어느 정도는 감안할 수 있는 수준이라고 생각된다. 모두가 아는 것처럼, 결론적으로 양질의 학습데이터가 모델 성능을 좌우하니, 충분한 학습데이터를 확보해 학습한다면 보다 나은 수준의 성능을 기대할 수 있지 않을까 생각한다. 이제 모델의 성능 향상을 위한 튜닝 및 고도화 과정이 시작된 것이다.

 

'EasyOCR 사용자 모델 학습하기' 시리즈는 EasyOCR 프로젝트에서 사용자 모델을 사용할 수 있고, 또 해당 모델을 기반으로 미세조정이 가능하다는 점 때문에 본 포스팅을 시작했다. 하지만, 포스팅을 진행하다 보니 EasyOCR 프로젝트보다는 deep-text-recognition-benchmark 프로젝트에 좀 더 흥미를 느꼈고, 포스팅의 내용도 deep-text-recognition-benchmark 프로젝트 위주로 작성한 게 아닌가 싶긴 하다.

 

어쨌든, 이제 양질의 학습데이터를 확보한 후, 보다 다양한 방법과 다양한 모델들을 학습해 가면서 최상의 성능을 보여주는 모델을 찾아야 할 것 같다.

 


 

여기까지, 'EasyOCR 사용자 모델 학습하기' 시리즈는 마무리한다.

 

다음은 EasyOCR과 유사한 오픈소스 기반의 유명한 OCR 엔진인 Tesseract OCR 모델 학습 및 사용 방법에 대한 내용은 아래 링크를 참고하기 바란다.

 

[Development/OCR] - Tesseract OCR 4.x 모델 학습하기 (1)

 

Tesseract OCR 4.x 모델 학습하기 (1)

Tesseract에서 제공하는 API를 통해 OCR 기능을 이용할 때 사용되는 기본 학습모델이 아닌 사용자가 직접 학습시키고자 하는 데이터를 준비해 학습하고 원하는 성능의 모델을 만들어 사용할 수 있는

davelogs.tistory.com

 

 

 

728x90
반응형