Pytorch로 시작하는 딥러닝 입문(03-05. 클래스로 파이토치 모델 구현하기)

2024. 2. 21. 21:24딥러닝 모델: 파이토치

파이토치의 대부분의 구현체들은 대부분 모델을 생성할 때 클래스를 사용한다. 앞서 배운 선형회귀를 클래스로 구현해보자. 오직 클래스로 모델을 구현한다는 점이 앞서 구현한 코드와 다른 점이다. 

♣ 모델을 클래스로 구현하기 

 

앞서 구현했던 단순선형회귀모델을 클래스로 구현한다. 

 

 

위와 같이 클래스를 사용한 모델 구현 형식은 대부분의 파이토치 구현체에서 사용하고 있는 방식이다. 클래스 형태의 모델은 nn.Module을 상속받는다. 그리고 __init__( )에서 모델의 구조와 동작을 정의하는 생성자를 정의한다. 이는 파이썬에서 객체가 갖는 속성값을 초기화하는 역할로, 객체가 생성될 때 자동으로 호출된다. 

torch.nn.Linear()를 사용할 때 in_features, out_features, bias=True가 파라미터로 들어간다. in_features는 입력되는 feature의 차원 수, out_features는 출력되는 feature의 차원 수이다. bias는 bias 항을 사용할지 여부를 의미한다. bias는 명시적으로 지정하지 않더라도 True 값을 기본으로 한다. 

 

 

 

super( ) 함수를 부르면 여기서 만든 클래스는 nn.Module 클래스의 속성들을 가지고 초기화된다. 상속받은 클래스의 메서드를 호출할 수 있게 하고, 이를 통해 코드의 재사용성과 유지보수성을 높인다.

forward( ) 함수는 모델이 학습데이터를 입력받아서 forward 연산을 진행시키는 함수이다. 이 forward( ) 함수는 model 객체를 데이터와 함께 호출하면 자동으로 실행된다. 예를 들어 model이란 이름의 객체를 생성한 후 mode(입력 데이터)와 같은 형식으로 객체를 호출하면 자동으로 forward 연산이 수행된다. 

 

즉 H(x) 식을 이용하여 입력 x로부터 예측된 y를 얻는 것을 forward 연산이라고 한다. 

 

 

앞서 구현했던 다중선형회귀모델도 클래스로 구현한다. 

 

 

♣ 단순선형회귀 클래스로 구현하기

 

♣ 다중선형회귀클래스로 구현하기

 

보면 알겠지만 class를 정의한 부분만 다르고 나머지는 모두 같다.