Julie의 Tech 블로그

Secret of Long Context Length 본문

Tech/ML, DL

Secret of Long Context Length

Julie's tech 2023. 11. 25. 16:51
728x90

배경

요즈음 등장하는 LLM의 context length는 점점 더 길어지고 있다. 누가 더 window가 긴 AI 모델을 개발하느냐가 관심을 끌고 있다. 하지만 Transformer 모델의 구조상 길이가 길어질수록 연산 복잡도가 증가함에 따라 성능이 떨어지는 점과 속도가 뒤처지는 것은 피할 수 없는 문제이다. 그럼 긴 context window를 갖고 있는 AI 모델은 어떤 비결이 있는걸까?

 

문제제기

input token n 의 값이 커지면

- attention layer의 time & space complexity가 높아진다. (quadratic)

- embedding size d가 커짐으로써 embedding layer의 time & space complexity가 높아진다. (quadratic)

- positional sinusoidal embedding에서도 문제가 발생한다.

 

사실상 2K context length로 학습된 Transformer 모델도 100K 정도든 어느 양의 token을 넣을 수도 있지만 의미있는 결과를 만들어내지는 않는다. 그럼 그냥 단순하게 긴 context length로 학습하게 되면 비용이 기하급수적으로 증가한다. 현실적인 방법은 적당한 길이의 context로 학습한 후 fine-tuning하는 것이 방법이지만 이마저도 positional sinusoidal encoding으로 인해 original Transformer 모델 구조를 유지할 수 없어 모델 아키텍쳐에 변형이 필요하다.

 

요즈음의 사람들은 아래와 같은 방법을 취하여 context length 한계를 극복하려고 한다:

  • 문서를 요약시켜 프롬프트에 잘 chaining하여 넣어줌
  • 벡터 DB를 보유하여 embedding값으로 변환한 뒤 유사도 metric으로 접근
  • LLM 모델을 내 데이터로 fine-tuning함
  • 특정 데이터에 맞춘 작은 LLM 모델들을 개발함

하지만 model weight의 변경 없이 기억 영역에 잠시 올려두는 것은 큰 의미가 없기 때문에 large context window를 추구하게 된다. 긴 context window를 보유하게 되면 정확도 측면에서 훨씬 높아지기 때문이다. 결국 비용과 연산속도를 절감하는 방향으로 context window 크기를 늘릴 수 밖에 없다.

 

그럼 어떻게 저렴하면서도 연산복잡도가 낮은 방향으로 context length가 큰 모델을 개발할 수 있을까?

 

Tricks

  • [1] positional sinusoidal encoding 대신 다른 positional encoding 알고리즘을 활용

positional encoding은 언어모델에서 꽤 중요하다. 토큰의 위치에 대한 정보를 담고 있기 때문에 sequence order를 알려주는 유일한 존재인데 이 값이 적절하게 매핑이 안된다면 모델은 인풋 토큰을 모두 각기 프로세싱할 것이고 문장을 이해하기 어려워지게된다.

앞서 작은 context window 데이터로 학습하고 긴 토큰 데이터로 fine-tuning하게 될 때 positional sinusoidal encoding을 그대로 사용하기는 어렵다는 이야기가 있었다. 그 이유는 이 인코딩 기법은 extrapolation 기능이 없기 때문에 추론할 때 context length가 길어지면 성능이 떨어지기 때문이다.

positional sinusoidal encoding을 제하게 된다면 다른 방법 중 하나로는 ALiBI(Attention with Linear Biases)가 있다. 이 position embedding 방식은 attention head layer에서 사용되는데 query-key attention score가 softmax를 취하기 전에 거리에 따라 패널티를 부여하는 방식으로 score 값에 변형을 준다. 이 방식은 학습시간을 단축시키는 효과가 있다. 왜냐하면 attention head layer에서 상수(constant)를 더하기 때문에 모델 학습 사이클에서 학습 대상이 되지는 않기 때문이다.

Llama2 논문에서도 비슷한 이야기를 한다. PE대신 RoPE를 사용하는데 이 때 새로운 하이퍼파라미터 base frequency b를 도입하여 먼 거리에 있는 토큰의 decaying effect를 줄이는 데 사용함. addition이 아니라 rotation의 방식이다. 이는 긴 context 데이터로 pretraining한 뒤 context window를 넓히는 것보다 PE 알고리즘 자체를 바꾸어 성능 차이에 큰 변화를 만들어냈다고 한다.

 

  • [2] sparse attention score를 사용하여 속도를 높임

모든 토큰들이 서로 연관이 있진 않기 때문에 연산 횟수를 줄이기 위해서는 attention score를 계산할 때 고려할 토큰의 수를 줄이는 것이다. 이 방식이 궁극적으로 하고자 하는 것은 sparsity를 추구하여 연산 복잡도가 증가할 방향이 quadratic하지 않고 linear하고자 하는 것이다.

그 중 하나는 sliding window attention이라고 하여 window 크기를 지정하고 각 토큰별 그 window 내부에 있는 토큰들간의 attention pattern만을 고려하는 것이다. 이 방식을 local attention이라고 일컫기도 하는데, 이를 포함한 global, random 방식까지 혼용해서 attention score를 계산하는 BigBird attention 메커니즘도 있다. 이처럼 학습과 추론 속도를 높여 시간을 절약하는 시도도 있다.

 

  • [3] flash attention으로 학습과 추론에 걸리는 속도를 높임

이 방식은 연산속도를 높이기 위해 GPU에서의 메모리 hierarchy상 가장 빠른 SRAM에 attention score계산에 필요한 K, Q, V matrix들을 올려 연산하는 것을 제안한다. 소위 tiling이라고 하여 K, Q, V를 여러 블락 단위로 나눈 뒤 이것을 SRAM에 올려 attention을 계산한다. GPU 자체도 matrix 연산에 최적화되어있긴 하지만 이 방식은 더 나아가 attention layer 자체를 GPU 연산 방식에 맞추어 최적화한 시도로 볼 수 있다.

 

  • [4] multi-head attention이 아니라 multi-query attention

Multi-head attention 방식에서는 K, V matrix가 매 head마다 linear layer가 별도로 존재한다. 그래서 추론할 때 decorder에서는 이전 토큰의 key와 value값을 캐싱해두어 재연산할 필요를 덜게 되는데, 이 때 각 토큰마다의 GPU 메모리 사용량이 늘어나게 된다. Multi-query attention은 attention head들 사이에 weight를 서로 공유하여 K, V를 각각 linear projection할 때 2개의 matrix만을 보유해도 가능하도록 변형한 방식이다.

이 방식은 context length가 길어질수록 더 효과를 보는데 추론할 동안에 attention score 연산에 소요되는 시간을 줄이기 때문이다.

 

  • [5] conditional computation

모종의 조건제 계산방식을 도입해서 인풋 문장의 토큰들에 모델의 모든 파라미터들을 적용하는 것을 회피하려한다. 예를 들어 CoLT5 논문에서는 heavy와 light로 attention 계산방식을 두 가지로 나누었는데, light layer는 모든 토큰에 대해 attention score를 계산하고, heavy layer는 중요한 토큰에 대해서만 attention score를 계산한다. 이 방식은 정확도 뿐만 아니라 속도까지도 높은 방식이라고 소개된다.

 

  • [6] large RAM GPUs

그냥 당연한 사실을 적어두었다. GPU RAM사이즈가 커야만 large context를 대상으로 모델을 학습할 수 있기 때문이다.

 

결론

결국 input context length가 길어지면 연산의 복잡도가 올라가서 그를 줄이고자 하는 방법이 여러 가지로 제안되고 있는데, 그 중에 가장 약점으로 꼽고 있는 곳이 positional encoding 부분인 것으로 보인다. Extrapolation 기능이 없다보니 input이 길어지게 되면 추론 성능이 많이 떨어지는 것이 이슈가 되어 알고리즘 자체를 교체하는 방법이 다양한 것 같다. 그 외에는 layer별 연산 개수를 줄이는 방법들인데 과거에도 유사하게 적용되었던 방식을 많이 차용하는 것으로 보인다. window를 만들어보거나 layer를 나누고 혹은 sparse하게 만들어보는 등.

반응형