Reasoning in Large Language Models
Let’s start this blog with a task: train a model that concatenates the last letters of two input words. For example, if the input words are ‘Elon’ and ‘Musk’, the model should return ‘nk’. If we use supervised learning, we need many training examples covering words with many different final letters before the model reliably produces the correct output. One might argue that few-shot learning with an LLM like GPT-3 solves this problem; however, the model still isn’t able to produce the right output.
This brings us to a turning point: we want large language models to learn from a few examples and still solve the task correctly. Humans typically learn from a handful of examples and solve new problems using reasoning, skills picked up in the early stages of development through illustrations and a whole lot of practice. Perhaps the solution, then, is to teach LLMs to reason more like children do: with carefully chosen training data that develops the capacity for logic and deduction.
Defining reasoning
Reasoning is about finding a logical and systematic way of doing some task, typically by building valid arguments and drawing sensible conclusions. We can break reasoning down further: deductive reasoning derives a conclusion from the truth of its premises; inductive reasoning generalizes a conclusion from observations; abductive reasoning picks the best explanation for a set of observations; and analogical reasoning draws a conclusion by comparing two similar events or situations.
Let’s dive into some of the methods that adopt this idea to improve and elicit reasoning in large language models.
Fully supervised finetuning
These methods are typically employed with smaller language models, finetuning them on a rationale/reasoning dataset so that the model generates rationales explaining its predictions. For example, Hendrycks et al. [1] finetuned pretrained language models to solve math problems by generating step-by-step solutions.
However, this method requires an explicit dataset containing the reasoning steps and ties the model to that specific dataset, restricting generalization.
Reasoning through prompting
As language models grow larger (on the order of 100B parameters), emergent abilities like reasoning have been observed [2]. However, reasoning does not surface on its own; it needs to be prompted for. Chain-of-thought (CoT) prompting [3] emulates the human thought process by prompting the model to produce its intermediate reasoning steps in natural language.
The basic structure of a CoT prompt is <input, chain of thought, output> instead of the standard <input, output> pair used in few-shot learning. Because a handful of few-shot exemplars is enough to guide the model to generate intermediate reasoning steps, this requires neither a large dataset nor any modification of the model’s weights.
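As an illustration, here is a minimal sketch of how a few-shot CoT prompt could be assembled for the last-letter-concatenation task from the introduction; `llm_generate` is a hypothetical stand-in for whatever text-completion call you use, and the exemplar wording is my own.

```python
# Minimal sketch of a few-shot chain-of-thought prompt for the
# last-letter-concatenation task. `llm_generate` is a hypothetical
# stand-in for any text-completion API.

FEW_SHOT_COT = """Q: Concatenate the last letters of 'Bill' and 'Gates'.
A: The last letter of 'Bill' is 'l'. The last letter of 'Gates' is 's'. Concatenating them gives 'ls'. The answer is ls.

Q: Concatenate the last letters of 'Larry' and 'Page'.
A: The last letter of 'Larry' is 'y'. The last letter of 'Page' is 'e'. Concatenating them gives 'ye'. The answer is ye.
"""

def cot_prompt(word1: str, word2: str) -> str:
    # <input, chain of thought, output> exemplars followed by the new input.
    return FEW_SHOT_COT + f"\nQ: Concatenate the last letters of '{word1}' and '{word2}'.\nA:"

# answer = llm_generate(cot_prompt("Elon", "Musk"))
```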
Zero-shot CoT
One problem with few-shot CoT is that it requires an initial set of hand-written reasoning exemplars for each type of reasoning task. Kojima et al. [3] propose a zero-shot CoT method that uses an instruction like “Let’s think step by step” to prompt the model to generate the reasoning chain on its own.
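A rough sketch of this two-stage recipe is below: first append the trigger phrase to elicit a reasoning chain, then feed the chain back and ask for the final answer. `llm_generate` is again an assumed placeholder, not a real API.

```python
# Sketch of a two-stage zero-shot CoT procedure.
# `llm_generate` is an assumed placeholder for a text-completion call.

def zero_shot_cot(question: str, llm_generate) -> str:
    # Stage 1: elicit the reasoning chain with the trigger phrase.
    reasoning_prompt = f"Q: {question}\nA: Let's think step by step."
    reasoning = llm_generate(reasoning_prompt)

    # Stage 2: append the chain and ask for the final answer.
    answer_prompt = reasoning_prompt + " " + reasoning + "\nTherefore, the answer is"
    return llm_generate(answer_prompt)
```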
Reasoning through code generation
Some LLMs trained on code, such as Codex [4] [5], achieve better performance on CoT-style reasoning tasks by framing reasoning as a code generation problem.
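The idea, roughly, is to ask the model to emit a short program whose execution produces the answer, and to run that program instead of trusting the model’s own arithmetic. A toy sketch under those assumptions (the prompt wording and `llm_generate` are illustrative, not from the cited papers):

```python
# Toy sketch of framing reasoning as code generation: the model writes a
# small Python function and the host executes it to obtain the answer.
# `llm_generate` and the prompt wording are illustrative assumptions.

def solve_via_code(question: str, llm_generate):
    prompt = (
        "Write a Python function solve() that returns the answer to:\n"
        f"{question}\n"
        "Reply with code only."
    )
    code = llm_generate(prompt)

    namespace = {}
    exec(code, namespace)  # caution: execute model-written code in a sandbox
    return namespace["solve"]()
```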
Rationale Exploration
Chain-of-thought prompting relies on greedy decoding, producing a single reasoning path. But complex reasoning tasks typically admit multiple reasoning paths that reach the right answer, however diverse their arguments. Wang et al. [6] propose self-consistency, which samples from the language model’s decoder to generate a diverse set of reasoning paths. Each sample introduces a latent variable $r_i$, the sequence of tokens in the reasoning path, which leads to a final answer $a_i$. Multiple such pairs $(r_i, a_i)$ are sampled, and the reasoning paths are marginalized out by taking a majority vote over the answers $a_i$, yielding the most consistent answer.
CoT with self-consistency is a very powerful strategy to elicit reasoning in complex tasks, especially arithmetic reasoning.
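Concretely, self-consistency can be approximated by sampling several completions at a nonzero temperature, parsing an answer $a_i$ out of each reasoning path $r_i$, and returning the majority answer. The sketch below assumes a `sample(prompt, temperature)` callable and a task-specific `extract_answer` parser; both are placeholders.

```python
from collections import Counter

# Sketch of self-consistency: sample several reasoning paths (r_i, a_i)
# and marginalize over r_i by majority vote on a_i.
# `sample` and `extract_answer` are assumed, task-specific helpers.

def self_consistency(prompt: str, sample, extract_answer, n: int = 10) -> str:
    answers = []
    for _ in range(n):
        path = sample(prompt, temperature=0.7)  # one sampled reasoning path r_i
        answers.append(extract_answer(path))    # its final answer a_i
    # The most frequent answer is taken as the "most consistent" one.
    return Counter(answers).most_common(1)[0][0]
```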
Problem Decomposition
CoT as a standalone prompting procedure can still struggle with reasoning tasks that require multi-hop chains. Like humans, LLMs handle complex problems better when they decompose them into smaller pieces.
Least-to-most prompting
This method [7] decomposes the base problem into multiple subproblems and applies CoT to each of them, solving the subproblems in a particular order so that earlier answers can inform later ones.
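A rough sketch of this two-stage loop: one prompt asks the model to decompose the problem into subquestions, then each subquestion is answered in order with the earlier answers appended to the context. `llm_generate` and the prompt wording are assumptions for illustration, not the paper’s exact prompts.

```python
# Sketch of least-to-most prompting: decompose, then solve subproblems
# sequentially, feeding earlier answers back into the context.
# `llm_generate` and the prompt wording are illustrative assumptions.

def least_to_most(question: str, llm_generate) -> str:
    # Stage 1: ask the model to break the problem into ordered subquestions.
    decomposition = llm_generate(
        f"Break the following problem into simpler subquestions, one per line:\n{question}"
    )
    subquestions = [q.strip() for q in decomposition.splitlines() if q.strip()]

    # Stage 2: answer each subquestion, accumulating the solved context.
    context = question
    answer = ""
    for sub in subquestions:
        answer = llm_generate(f"{context}\nQ: {sub}\nA:")
        context += f"\nQ: {sub}\nA: {answer}"
    return answer  # the final subquestion's answer resolves the original problem
```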
Discussion and thoughts
Reasoning seems to be an emergent ability of LLMs, improving as the parameter count of the model increases [2]. Moreover, ‘spoon-feeding’ the model with example- or instruction-based reasoning chains through CoT elicits reasoning in LLMs. This suggests that LLMs respond to reasoning tasks in a surprisingly human-like way, which opens up research directions that apply other human-like strategies for reasoning.
While significant progress has been made in ‘eliciting’ reasoning in LLMs, it might be more helpful to understand how the emergent ability to reason appears in LLMs and what the extent of this ability is. Until then, happy prompting!