본문 바로가기

딥러닝

밑바닥부터 시작하는 딥러닝 5장 -7,8 - 덥셈, 곱셈 노드의 역전파(텐써플로우 포함)

■ 7. 덧셈노드의 역전파 

아래는 덧셈노드의 순전파 그림이다.  이 그림의 역전파를 알아보자 

그림 5-9. z=x+y

역전파를 보면 

상류층 L에서 미분되어 역전파로 흐르고 있다. aL/az( z에 대한 L의 미분) 

여기서 그림 5-9 의 부분(덧셈노드)의 역전파를 살펴보면 합성함수 미분을 생각하면 된다.

상류층 미분 * 덧셈노드 미분 

덧셈노드 미분은 z=x+y를 각각 x에 대해 미분, y에 대해 미분을 하니까 1이 나온다.

그래서 그림과 같이 

덧셈노드에서 x로 흐르는 값과 y 로 흐르는 값이 똑같이 aL/az*1 이 된다.

 

덧셈노드 순전파와 역전파를 텐써플로우 코드 

class AddLayer:
    def _init_(self):
        pass
    def forward(self,x,y):   # 순전파 
        out= x+y
        return out
    
    def backward(self,dout):  #역전파  # dout : 상류층 미분한 값 (aL/az) 
        dx = dout * 1 # 1은 x
        dy = dout * 1
        return dx, dy

덧셈노드 아래 그림의 값으로 객체화 시켜서 순전파와 역전파를 구해보자(객체화 시키기)

add_layer = AddLayer()
add_layer.forward(10,5)


add_layer = AddLayer()
add_layer.backward(1.3)

객체화 시킨 코드이다. 

 

■ 곱셈노드의 역전파 p.157

곱셈노드의 역전파를 살펴보면 

z=x*y 이므로 

위에서와 같이 상류층 노드와 합성함수 미분을 해야하므로 

상류층 노드 미분 (aL/az)* 현재 곱셈 노드 미분 값이다.

각각 현재 곱셈노드를 살펴보면 

z=x*y 이므로 

x에 대한 미분은 y    --> (aL/az)*y

y에 대한 미분은 x 이므로 -->(aL/az)*x 

 

그러므로 곱셈노드의 역전파는  상류에서 흘러왔던 값에 상대편 쪽의 값을 곱해줘야한다. 

 

 

곱셈노드 순전파와 역전파 함수 클래스를 텐써플로우 코드 

class  MulLayer:
    def  __init__(self):   #  클래스를 가지고 객체를 생성할 때 바로 작동되는 함수
        self.x = None     # 사과 가격이 잘못 계산될 수 있으므로 처음에 
        self.y = None     # x 와 y 를 둘다 None 으로 초기화 합니다. 

    def  forward(self, x, y):  # 순전파 함수 
        self.x = x          #  x 에 사과 가격이 입력됩니다. 
        self.y = y          # y 에는 사과 갯수가 입력됩니다.
        out = x* y       # 계산해서
        return  out       # 리턴합니다. 

    def  backward(self, dout) :  # 역전파 함수
        dx = dout * self.y
        dy = dout * self.x

        return  dx, dy

객체화 시켜보기 (우선 순전파만 구해보았다.) 

#1. 층 4개를 구성합니다.
apple_price_layer = MulLayer()
orange_price_layer = MulLayer()
add_apple_orange_layer = AddLayer()
mul_tax_layer = MulLayer()

# 2. 순전파 구현

apple_price = apple_price_layer.forward( 100, 2 )
orange_price = orange_price_layer.forward( 150, 3 )
all_price = add_apple_orange_layer.forward( apple_price, orange_price )
price = mul_tax_layer.forward( all_price, 1.1)
print(price)  # 715.0000000000001