Clase 8: Redes Recurrentes
Hasta ahora, hemos asumido que los datos con los que trabajamos son independientes e idénticamente distribuidos (i.i.d). Sin embargo, en muchos casos, los datos tienen una estructura secuencial que debe ser considerada al momento de modelarlos. Algunos ejemplos comunes de datos secuenciales incluyen:
También pudiesen existir datos “multimodales”, donde por ejemplo, se combinan secuencias con imágenes.
Supongamos el siguiente ejemplo:
Nos gustaría poder entrenar un modelo capaz de trabajar con datos de entrada de distintas longitudes (secuencias de tamaño variable).
Queremos que el modelo sea capaz de utilizar información pasada para realizar predicciones sobre valores futuros.
Pros
Cons
Nomenclatura
🚨 Una RNN tiene dos set de parámetros: \(W_{ih}\) y \(b_{ih}\), los cuales representan la transformación entre los valores de entrada y el estado oculto, y \(W_{hh}\) y \(b_{hh}\) los que representan la transformación entre el estado oculto previo y el actual (Feedback Loop).
👀 Es importante mencionar que las RNN son aplicadas a la secuencia elemento a elemento.
rnn = nn.RNN(input_size=1, hidden_size=2, num_layers=1, batch_first=True)
def forward_pass_rnn(x):
N, L, D = x.shape
h=[0]
for seq_id in range(L):
h.append(torch.tanh(x[:,seq_id,:]*rnn.weight_ih_l0 + rnn.bias_ih_l0 + h[-1]*rnn.weight_hh_l0 + rnn.bias_hh_l0))
return htensor([[0.4727, 0.1731]],
[[ 0.9106, -0.9335]],
[[ 0.7067, -0.9780]])
Shape: (1,3,2)
🚨 Básicamente la RNN genera una transformación affine por cada time step \(t\) agregando información de su memoria pasada. Es decir, cada elemento de la secuencia es transformado y llevado a un número de dimensiones igual a hidden_size.
☝️ Nomenclatura en Pytorch
📢 Unrolling de una RNN
El unrolling consiste en representar una RNN como una secuencia virtual de capas conectadas que permite ver la relación entre los elementos de cada time step. Es importante notar que los parámetros se comparten, es decir, cada una de las capas tiene los mismos pesos y bias.
vanishing gradient problem y el exploding gradient problem.\[Gradiente = f(Input \times W_2^{N_{Unroll}})\]
Cuando los valores de \(W_2\) son muy pequeños (menores que 1), el gradiente tiende a desvanecerse (vanishing gradient).
En cambio, si los valores de \(W_2\) son muy grandes (mayores que 1), el gradiente tiende a explotar (exploding gradient).
$W_2^{N_unroll} aparece en la ecuación al momento de comenzar a derivar de manera recursiva.
Las Vanilla RNNs se utilizan muy poco en la práctica; sin embargo, tienen una relevancia histórica significativa, ya que sentaron las bases para el desarrollo de arquitecturas más avanzadas, como las LSTMs y los Transformers.
✨✨ Part of Speech Tagging: Entrego una Secuencia de Largo L y obtengo una Secuencia de Largo L con Clases Asociadas.
class POSTaggingRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.rnn = nn.RNN(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True,
)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
rnn_out, hn = self.rnn(x)
print("Tamaño del Output de RNN: ", hn.shape)
logits = self.fc(rnn_out)
return logits
model = POSTaggingRNN(input_size=1,
hidden_size=4,
output_size=3)
output = model(x1)
print("Shape del Output Final: ", output.shape)
outputTamaño del Output de RNN: torch.Size([1, 1, 4])
Shape del Output Final: torch.Size([1, 3, 3])
tensor([[[-0.6526, 0.2722, 0.2755],
[-1.1120, 0.7727, 0.4798],
[-1.0416, 0.7447, 0.4793]]])
✨✨ Sentiment Analysis: Entrego una Secuencia de Largo L y obtengo una Clase Final (Positiva, Negativa, Neutra).
class SentimentAnalysisRNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.rnn = nn.RNN(
input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True,
)
self.fc = nn.Linear(hidden_size, output_size)
def forward(self, x):
rnn_out, hn = self.rnn(x)
print("Tamaño del Hidden State del RNN: ", hn.shape)
logits = self.fc(hn)
return logits
model = SentimentAnalysisRNN(input_size=1,
hidden_size=4,
output_size=3)
output = torch.softmax(model(x1), dim=-1)
print("Shape del Output Final: ", output.shape)
outputTamaño del Hidden State del RNN: torch.Size([1, 1, 4])
Shape del Output Final: torch.Size([1, 1, 3])
tensor([[[0.0866, 0.5169, 0.3965]]])
Es posible juntar varias capas recurrentes, para que las salidas de una alimenten un siguiente Hidden State, y que luego de algunas capas efectivamente se llegue a las salidas de interés.
Debido a que hacer esto es complicado esto viene integrado en la implementación en Pytorch mediante el parámetro num_layers.
OJO
No existen salidas intermedias, sino que los Hidden States de capas anteriores son utilizados directamente como inputs de los hidden states posteriores.
En Pytorch los Hidden States se devuelven concatenados. Es común utilizar el último Hidden State, es decir, la última salida de la última capa como Input Features para una capa Fully Connected.
A diferencia de otro tipos de Redes como las Convolucionales o FFN, la profundidad en este tipo de redes es de bastante menos impacto.
vanishing gradient problem.
Posee un funcionamiento similar a la RNN, sólo que el “hidden state” se divide en dos partes: \(h_t\) y \(C_t\), llamados hidden state (corto plazo) y cell state (largo plazo) respectivamente.
Spoiler: El Hidden y Cell State está compuesto por multiples set de parámetros a los cuales se les dan los nombres de forget gate, input gate, cell gate y output gate. Su interpretabilidad nunca ha logrado ser completamente explicada.
La LSTM está regida por las siguientes ecuaciones:
\[i_t = \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi})\]
\[f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf})\]
\[g_t = tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg})\]
\[o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho})\]
\[c_t = f_t \odot c_{t-1} + i_t \odot g_t\]
\[h_t = o_t \odot tanh(c_t)\]
Todas estos elementos \(i_t,f_t, g_t,o_t, c_t,h_t \in \mathbb{R}^d\), donde \(d\) es el “hidden_size”.
Forget Gate
Cell State.porcentaje a olvidar.\[f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf})\]
Input Gate
Cell State.\[i_t = \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1} + b_{hi})\]
Cell Gate
Cell State.\[g_t = tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg})\]
Output Gate
Determina qué “porcentaje” de información del “Cell State” debe salir como “Hidden State” para el tiempo \(t\) actual.
\[o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho})\]
Hidden State
\[h_t = o_t \odot tanh(c_t)\]
Cell State
Representa la principal innovación de este tipo de redes ya que permite recordar dependencias de largo plazo (es decir time steps anteriores en secuencias largas). Esto ya que el Cell State puede avanzar casi sin interacciones lineales (no hay parámetros que influyen en ella, por lo que no es afectada por problemas de gradientes).
\[c_t = f_t \odot c_{t-1} + i_t \odot g_t\]
Hidden State
Representa la potencial actualización del “Hidden State”.
\[h_t = (1-z_t) \odot n_t + z_t \odot h_{t-1}\]
Update Gate
Controla qué porcentaje del “hidden state” previo se lleva al siguiente paso.
\[z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz}h_{t-1} + b_{hz})\]
Reset Gate
Controla cuánta información del pasado se debe olvidar.
\[r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr}h_{t-1} + b_{hr})\]
Candidate Hidden State
Representa la potencial actualización del “Hidden State”.
\[n_t = tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn}))\]
Existen ocasiones en las que se requiere no sólo el contexto de los tiempos anteriores, sino también de los posteriores. Por ejemplo, problemas de traducción.
Para ello existen las redes bidireccionales, en la cual se agrega una segunda capa pero que mueve los hidden state en el otro sentido.
nn.RNN(input_size, hidden_size, num_layers=1, batch_first=False,
dropout=0, bidirectional=False, nonlinearity="tanh")input_size.False.