본문 바로가기
Pytorch

[Pytorch] 구글 드라이브에서 pretrained 모델(pth) torchvision 다운폴더로 받아서 실행하기.

by pulluper 2023. 5. 23.
반응형

안녕하세요 pulluper 입니다. 

 

최근에 개발을 하다가 제가 학습시켜놓은 pth 파일을 torchvision처럼 다운받아서 바로 돌리도록 코드를 작성했습니다.

 

import os
import torch
import gdown


def download_pretrained_model(pth_name, file_id=None):
    google_path = 'https://drive.google.com/uc?id='
    if file_id is None:
        return None
    torch_dir = torch.hub.get_dir()
    output_name = pth_name
    if os.path.exists(os.path.join(torch_dir, 'checkpoints', output_name)):
        print("Already downloads!")
    else:
        gdown.download(google_path+file_id,
                       os.path.join(torch_dir, 'checkpoints', output_name),
                       quiet=False)

 

pth_name 은 다운받을 이름이고(e.g.vgg16_custom_model.pth), file_id 는 구글드라이브 파일 아이디 입니다. 

 

위와 같은 코드에 file_id (공유된 구글드라이브의 파일 아이디) 를 넣어주면, 다음과 같이 

 

(윈도우 기준) "C:\Users\유저이름\.cache\torch\hub\checkpoints" 에 다른 torchvision pth 파일이

 

있는곳에 자신의 pth 파일이 다운이 됩니다. 

 

현재 제 torch.hub.get_dir() + "\checkpoints" 폴더에 들어있는 파일은 다음과 같습니다. 

 

 

이제 구글드라이브의 pretrained pth 파일을 다운받을 수 있게 되엇고 이를 load를 시켜서 실행을 시킬 수 있습니다. 

 

감사합니다. 😊😊😊

반응형

댓글