Clase 9: Redes Recurrentes
Hasta ahora, asumimos que todos nuestros datos son i.i.d. Pero en la realidad existen datos que contienen una secuencia temporal y que debe ser considerada al momento de modelarlos.
También pudiesen existir datos “multimodales”, donde por ejemplo, se combinan secuencias con imágenes.
En este caso consideramos que \(x_t\) corresponde a la instancia de un dato en el tiempo \(t\) asociado a un target \(y_t\). Además tenemos cierta dependencia entre los elementos en \(t\) y \(t+1\).
Es importante notar que \(x_t\) no tiene por qué ser un escalar, sino que puede ser un vector de constituidos por varios features.
Ojo
Un dato corresponde una secuencia de elementos, por lo tanto \(x = [x_1,x_2,...x_L]\), donde \(L\) es el largo de la secuencia.
Es importante recalcar que la correcta descripción de cada palabra depende del contexto en el que se está usando y no sólo la palabra en sí misma.
El contexto ayuda a interpretar cuál es la manera correcta de interpretar el sonido emitido.
donde:
\[h_t = f(W_{hh} \cdot h_{t-1} + W_{hx} \cdot x_t + b_h)\] \[y_t = g(W_{yh}\cdot h_{t} + b_y)\]
Es posible juntar varias capas recurrentes, para que las salidas de una alimenten un siguiente Hidden State, y que luego de algunas capas efectivamente se llegue a las salidas de interés.
A diferencia de otro tipos de Redes como las Convolucionales o FFN, la profundidad en este tipo de redes es de bastante menos impacto.
OJO
No existen salidas intermedias, sino que los Hidden States de capas anteriores son utilizados directamente como inputs de los hidden states posteriores.
\[Output = Input \times W_2^{N_{Unroll}}\]
Si \(W_2\) corresponde a parámetros muy pequeños (menores a 1), entonces, el gradiente se desvanecerá (vanishing gradient).
Si \(W_2\) corresponde a parámetros muy grandes (mayores a 1), entonces, el gradiente explotará (exploding gradient).
Esto ocurre ya que si intentamos derivar la función de pérdida con respecto a alguno de los parámetros, eventualmente \(W_2^{N_unroll}\) aparecerá en la ecuación, provocando dicho efecto en el gradiente.
Esta es quizás la razón más importante del por qué Vanilla RNNs
son usadas rara vez en la práctica. Lo importante histórica de este tipo de redes es que abrieron las puertas a sistemas más modernos que hoy en día sí son usados (LSTMs, y Transformers).
vanishing gradient problem
.
Corresponde a la misma forma de una RNN, sólo que el “hidden state” se divide en dos partes: \(h_t\) y \(C_t\), llamados “hidden state” y “cell state” respectivamente.
La manera de calcular el “Hidden State” y el “Cell State” es muchísimo más engorrosa.
Spoiler: El Hidden y Cell State está compuesto por multiples set de parámetros a los cuales se les dan los nombres de forget gate
, input gate
, cell gate
y output gate
. Su interpretabilidad nunca ha logrado ser completamente explicada.
La LSTM está regida por las siguientes ecuaciones:
\[i_t = \sigma(W_{ii}x_t) + b_{ii} + W_{hi}h_{t-1} + b_{hi}\]
\[f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf})\]
\[g_t = tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg})\]
\[o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho})\]
\[c_t = f_t \odot c_{t-1} + i_t \odot g_t\]
\[h_t = o_t \odot tanh(c_t)\]
Todas estos elementos \(i_t,f_t, g_t,o_t, c_t,h_t \in \mathbb{R}^d\), donde \(d\) corresponde a la “hidden dimension”.
Forget Gate
Cell State
.porcentaje
a olvidar.\[f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1} + b_{hf})\]
Input Gate
Cell State
.\[i_t = \sigma(W_{ii}x_t) + b_{ii} + W_{hi}h_{t-1} + b_{hi}\]
Cell Gate
Cell State
.\[g_t = tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1} + b_{hg})\]
Output Gate
Determina qué “porcentaje” de información del “Cell State” debe salir como “Hidden State” para el tiempo \(t\) actual.
\[o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1} + b_{ho})\]
Hidden State
\[h_t = o_t \odot tanh(c_t)\]
Cell State
Representa la principal innovación de este tipo de redes ya que permite recordar dependencias de largo plazo (es decir time steps anteriores en secuencias largas). Esto ya que el Cell State puede avanzar casi sin interacciones lineales (no hay parámetros que influyen en ella, por lo que no es afectada por problemas de gradientes).
\[c_t = f_t \odot c_{t-1} + i_t \odot g_t\]
Hidden State
Representa la potencial actualización del “Hidden State”.
\[h_t = (1-z_t) \odot n_t + z_t \odot h_{t-1}\]
Update Gate
Controla qué porcentaje del “hidden state” previo se lleva al siguiente paso.
\[z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz}h_{t-1} + b_{hz})\]
Reset Gate
Controla cuánta información del pasado se debe olvidar.
\[r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr}h_{t-1} + b_{hr})\]
Candidate Hidden State
Representa la potencial actualización del “Hidden State”.
\[n_t = tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn}))\]
Existen ocasiones en las que se requiere no sólo el contexto de los tiempos anteriores, sino también de los posteriores. Por ejemplo, problemas de traducción.
Para ello existen las redes bidireccionales, en la cual se agrega una segunda capa pero que mueve los hidden state en el otro sentido.
nn.RNN(input_size, hidden_size, num_layers=1, batch_first=False,
dropout=0, bidirectional=False, nonlinearity="tanh")
input_size
.False
.