
Warning

Documentation here is a work in progress

Translating Natural Language to BlendSQL

nl_to_blendsql

Takes a natural language question and attempts to parse a BlendSQL representation for answering it against a database.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `question` | `str` | The natural language question to parse | required |
| `db` | `Database` | Database to use in translating | required |
| `model` | `Model` | BlendSQL model to use in translating the question | required |
| `ingredients` | `Optional[Collection[Type[Ingredient]]]` | Which ingredients to treat as valid in the output parse. Only these ingredient descriptions are included in the system prompt. | required |
| `correction_model` | `Optional[Model]` | Model to use for CFG-guided corrections. Defaults to `model` if not provided. | `None` |
| `few_shot_examples` | `Union[str, FewShot]` | String prompt introducing few-shot nl-to-blendsql examples | `''` |
| `args` | `Optional[NLtoBlendSQLArgs]` | Optional `NLtoBlendSQLArgs` object containing additional parameters | `None` |
| `verbose` | `bool` | Whether to enable verbose (DEBUG-level) logging | `False` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `ret_prediction` | `str` | Final BlendSQL query prediction |

Examples:

from blendsql import LLMMap, LLMQA
from blendsql.models import TransformersLLM, OllamaLLM
from blendsql.nl_to_blendsql import nl_to_blendsql, NLtoBlendSQLArgs
from blendsql.db import SQLite
from blendsql.utils import fetch_from_hub
from blendsql.prompts import FewShot

db = SQLite(
    fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db")
)
parser_model = OllamaLLM("phi3", caching=False)
correction_model = TransformersLLM("Qwen/Qwen1.5-0.5B")

ingredients = {LLMMap, LLMQA}
filtered_few_shot = FewShot.hybridqa.filter(ingredients)

blendsql = nl_to_blendsql(
    "What was the result of the game played 120 miles west of Sydney?",
    db=db,
    model=parser_model,
    correction_model=correction_model,
    ingredients=ingredients,
    few_shot_examples=filtered_few_shot,
    verbose=True,
    args=NLtoBlendSQLArgs(
        max_grammar_corrections=5,
        use_tables=["w"],
        include_db_content_tables=["w"],
        num_serialized_rows=3,
        use_bridge_encoder=True,
    ),
)
Source code in blendsql/nl_to_blendsql/nl_to_blendsql.py
def nl_to_blendsql(
    question: str,
    db: Database,
    model: Model,
    ingredients: Optional[Collection[Type[Ingredient]]],
    correction_model: Optional[Model] = None,
    few_shot_examples: Union[str, FewShot] = "",
    args: Optional[NLtoBlendSQLArgs] = None,
    verbose: bool = False,
) -> str:
    """Takes a natural language question, and attempts to parse BlendSQL representation for answering against a databse.

    Args:
        question: The natural language question to parse
        db: Database to use in translating
        model: BlendSQL model to use in translating the question
        ingredients: Which ingredients to treat as valid in the output parse.
            Only these ingredient descriptions are included in the system prompt.
        correction_model: Model to use for CFG-guided corrections.
            Defaults to `model` if not provided.
        few_shot_examples: String prompt introducing few-shot nl-to-blendsql examples.
        args: Optional NLtoBlendSQLArgs object, containing additional parameters.
        verbose: Whether to enable verbose (DEBUG-level) logging

    Returns:
        ret_prediction: Final BlendSQL query prediction

    Examples:
        ```python
        from blendsql import LLMMap, LLMQA
        from blendsql.models import TransformersLLM, OllamaLLM
        from blendsql.nl_to_blendsql import nl_to_blendsql, NLtoBlendSQLArgs
        from blendsql.db import SQLite
        from blendsql.utils import fetch_from_hub
        from blendsql.prompts import FewShot

        db = SQLite(
            fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db")
        )
        parser_model = OllamaLLM("phi3", caching=False)
        correction_model = TransformersLLM("Qwen/Qwen1.5-0.5B")

        ingredients = {LLMMap, LLMQA}
        filtered_few_shot = FewShot.hybridqa.filter(ingredients)

        blendsql = nl_to_blendsql(
            "What was the result of the game played 120 miles west of Sydney?",
            db=db,
            model=parser_model,
            correction_model=correction_model,
            ingredients=ingredients,
            few_shot_examples=filtered_few_shot,
            verbose=True,
            args=NLtoBlendSQLArgs(
                max_grammar_corrections=5,
                use_tables=["w"],
                include_db_content_tables=["w"],
                num_serialized_rows=3,
                use_bridge_encoder=True,
            ),
        )
        ```
    """
    if verbose:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.ERROR)
    if args is None:
        args = NLtoBlendSQLArgs()
    if correction_model is None:
        correction_model = model
    parser: EarleyParser = load_cfg_parser(ingredients)
    system_prompt: str = create_system_prompt(
        ingredients=ingredients, few_shot_examples=few_shot_examples
    )
    serialized_db = db.to_serialized(
        use_tables=args.use_tables,
        num_rows=args.num_serialized_rows,
        include_content=args.include_db_content_tables,
        use_bridge_encoder=args.use_bridge_encoder,
        question=question,
    )
    if args.max_grammar_corrections == 0:
        return model.predict(
            program=ParserProgram,
            system_prompt=system_prompt,
            question=question,
            serialized_db=serialized_db,
            stream=verbose,
        )
    num_correction_left = args.max_grammar_corrections
    partial_program_prediction = ""
    ret_prediction, initial_prediction = None, None
    while num_correction_left > 0 and ret_prediction is None:
        residual_program_prediction = model.predict(
            program=ParserProgram,
            system_prompt=system_prompt,
            question=question,
            serialized_db=serialized_db,
            stream=verbose,
        )

        # Save the first prediction so we can fall back to it if no valid correction is found
        if initial_prediction is None:
            initial_prediction = residual_program_prediction
        program_prediction = (
            partial_program_prediction + " " + residual_program_prediction
        )

        if validate_program(program_prediction, parser):
            ret_prediction = program_prediction
            continue

        # Get the valid prefix and candidate continuations at the point where parsing failed
        prefix, candidates, pos_in_stream = obtain_correction_pairs(
            program_prediction, parser
        )
        # candidates = [i for i in candidates if i.strip() != ""]
        if len(candidates) == 0:
            logger.debug(
                Fore.LIGHTMAGENTA_EX + "No correction pairs found" + Fore.RESET
            )
            return prefix
        elif len(candidates) == 1:
            # If we only have 1 candidate, no need to call LLM
            selected_candidate = candidates.pop()
        else:
            # Generate the continuation candidate with the highest probability
            selected_candidate = correction_model.predict(
                program=CorrectionProgram,
                system_prompt=system_prompt,
                question=question,
                serialized_db=serialized_db,
                partial_completion=prefix,
                candidates=candidates,
            )

        # Try to use our selected candidate in a few ways
        # 1) Insert our selection into the index where the error occurred, and add left/right context
        #   Example: SELECT a b FROM table -> SELECT a, b FROM table
        inserted_candidate = (
            prefix + selected_candidate + program_prediction[pos_in_stream:]
        )
        if validate_program(inserted_candidate, parser):
            ret_prediction = inserted_candidate
            continue
        # 2) If rest of our query is also broken, we just keep up to the prefix + candidate
        partial_program_prediction = prefix + selected_candidate
        for p in {inserted_candidate, partial_program_prediction}:
            if validate_program(p, parser):
                ret_prediction = p

        num_correction_left -= 1

    if ret_prediction is None:
        logger.debug(
            Fore.RED
            + f"cannot find a valid prediction after {args.max_grammar_corrections} retries"
            + Fore.RESET
        )
        ret_prediction = initial_prediction
    ret_prediction = post_process_blendsql(
        ret_prediction, db, use_tables=args.use_tables
    )
    logger.debug(Fore.GREEN + ret_prediction + Fore.RESET)
    return ret_prediction
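
Note that with the default `NLtoBlendSQLArgs` (i.e. `max_grammar_corrections=0`), the function returns the parser model's first prediction directly and skips the correction loop entirely, and `correction_model` falls back to `model`. A minimal call relying on those defaults might look like the sketch below, reusing the database and model from the example above:

```python
from blendsql import LLMMap, LLMQA
from blendsql.models import OllamaLLM
from blendsql.nl_to_blendsql import nl_to_blendsql
from blendsql.db import SQLite
from blendsql.utils import fetch_from_hub

db = SQLite(
    fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db")
)
parser_model = OllamaLLM("phi3", caching=False)

# No `args` passed: NLtoBlendSQLArgs() defaults apply, so
# max_grammar_corrections=0 and the first parser prediction is
# returned without any CFG-guided correction.
blendsql = nl_to_blendsql(
    "What was the result of the game played 120 miles west of Sydney?",
    db=db,
    model=parser_model,
    ingredients={LLMMap, LLMQA},
)
```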

NLtoBlendSQLArgs

Source code in blendsql/nl_to_blendsql/args.py
@dataclass
class NLtoBlendSQLArgs:
    max_grammar_corrections: int = field(
        default=0,
        metadata={
            "help": "Optional int defining maximum CFG-guided correction steps to be taken. This is based on the method in https://arxiv.org/pdf/2305.19234."
        },
    )

    include_db_content_tables: Union[List[str], str] = field(
        default="all",
        metadata={
            "help": "Which database tables to add `num_serialized_rows` worth of content for in serialization."
        },
    )

    num_serialized_rows: int = field(
        default=3,
        metadata={
            "help": "How many example rows to include in serialization of database"
        },
    )

    use_tables: Collection[str] = field(
        default=None,
        metadata={"help": "Collection of tables to use in serialization to string"},
    )

    use_bridge_encoder: bool = field(
        default=True,
        metadata={
            "help": "Whether to use Bridge Content Encoder during input serialization"
        },
    )
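
As a quick illustration of these fields, the configuration used in the `nl_to_blendsql` example above could be built like this (a minimal sketch; the values are illustrative):

```python
from blendsql.nl_to_blendsql import NLtoBlendSQLArgs

# Allow up to 5 CFG-guided correction steps, serialize only table "w"
# (with content), include 3 example rows per table, and enable the
# Bridge Content Encoder during input serialization.
args = NLtoBlendSQLArgs(
    max_grammar_corrections=5,
    use_tables=["w"],
    include_db_content_tables=["w"],
    num_serialized_rows=3,
    use_bridge_encoder=True,
)
```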

Grammar-Based Correction

If you use the grammar correction feature of BlendSQL, please cite the original grammar prompting paper below.

@article{wang2024grammar,
  title={Grammar prompting for domain-specific language generation with large language models},
  author={Wang, Bailin and Wang, Zi and Wang, Xuezhi and Cao, Yuan and Saurous, Rif A. and Kim, Yoon},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2024}
}
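
The correction loop in `nl_to_blendsql` above follows this method: it validates each prediction against an Earley parser for the BlendSQL grammar, recovers the longest valid prefix and the grammatically legal continuations at the failure point, lets the correction model pick one, and splices it back in. Below is a simplified, self-contained sketch of that validate-then-splice pattern; `validate`, `get_candidates`, and `choose` are hypothetical stand-ins for the parser and correction model, and the sketch omits the re-prompting of the parser model that the real implementation performs each iteration:

```python
# Toy sketch of CFG-guided correction, mirroring the structure of
# nl_to_blendsql above. All callables here are hypothetical stand-ins,
# not BlendSQL APIs.
def cfg_correct(prediction: str, validate, get_candidates, choose,
                max_corrections: int = 5) -> str:
    for _ in range(max_corrections):
        if validate(prediction):
            return prediction
        # Longest valid prefix, legal continuations, and failure offset
        prefix, candidates, pos = get_candidates(prediction)
        if not candidates:
            return prefix
        choice = candidates[0] if len(candidates) == 1 else choose(prefix, candidates)
        # Splice the chosen continuation in at the failure point,
        # keeping the right-hand context
        spliced = prefix + choice + prediction[pos:]
        if validate(spliced):
            return spliced
        prediction = prefix + choice  # keep only up to the repaired point
    return prediction
```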

FewShot

A collection of few-shot examples, with some utility functions for easy manipulation.

Examples:

from blendsql import LLMMap, LLMQA
from blendsql.prompts import FewShot, Examples
# Fetch the examples for HybridQA
fewshot_prompts: Examples = FewShot.hybridqa
print(f"We have {len(fewshot_prompts)} examples")
# We can select a subset by indexing
first_three_examples = fewshot_prompts[:3]
# Additionally, we can filter to keep only those examples using specified ingredients
filtered_fewshot = fewshot_prompts.filter({LLMQA, LLMMap})
Source code in blendsql/prompts/_prompts.py
@dataclass
class FewShot:
    """A collection of few-shot examples, with some utility functions for easy manipulation.

    Examples:
        ```python
        from blendsql import LLMMap, LLMQA
        from blendsql.prompts import FewShot, Examples
        # Fetch the examples for HybridQA
        fewshot_prompts: Examples = FewShot.hybridqa
        print(f"We have {len(fewshot_prompts)} examples")
        # We can select a subset by indexing
        first_three_examples = fewshot_prompts[:3]
        # Additionally, we can filter to keep only those examples using specified ingredients
        filtered_fewshot = fewshot_prompts.filter({LLMQA, LLMMap})
        ```
    """

    hybridqa = Examples(open(Path(__file__).parent / "./few_shot/hybridqa.txt").read())

Examples

Class for holding few-shot examples.

Examples:

from blendsql.prompts import FewShot, Examples
fewshot_prompts: Examples = FewShot.hybridqa
print(fewshot_prompts[:2])
Output:

Examples:

This is the first example

---

This is the second example

Source code in blendsql/prompts/_prompts.py
@attrs
class Examples:
    """Class for holding few-shot examples.

    Examples:
        ```python
        from blendsql.prompts import FewShot, Examples
        fewshot_prompts: Examples = FewShot.hybridqa
        print(fewshot_prompts[:2])
        ```
        ```text
        Examples:

        This is the first example

        ---

        This is the second example
        ```
    """

    data: str = attrib()

    split_data: List[str] = attrib(init=False)

    def __attrs_post_init__(self):
        self.data = self.data.strip()
        self.split_data: list = self.data.split("---")

    def __getitem__(self, subscript):
        newline = (
            "\n\n"
            if (isinstance(subscript, int) and subscript == 0)
            or (isinstance(subscript, slice) and subscript.start in {0, None})
            else ""
        )
        return "Examples:" + newline + "---".join(self.split_data[subscript])

    def __repr__(self):
        return "Examples:\n\n" + self.data

    def __str__(self):
        return "Examples:\n\n" + self.data

    def __len__(self):
        return len(self.split_data)

    def is_valid_query(self, query: str, ingredient_names: Set[str]) -> bool:
        """Checks if a given query is valid given the ingredient_names passed.
        A query is invalid if it includes an ingredient that is not specified in ingredient_names.
        """
        stack = [query]
        while len(stack) > 0:
            for res, _start, _end in peg_grammar.scanString(stack.pop()):
                if res.get("function").upper() not in ingredient_names:
                    return False
                for arg in res.get("args"):
                    stack.append(arg)
        return True

    def filter(self, ingredients: Iterable[Type[Ingredient]]) -> "Examples":
        """Retrieve only those prompts which do not include any ingredient not specified in `ingredients`."""
        ingredient_names: Set[str] = {
            ingredient.__name__.upper() for ingredient in ingredients
        }
        filtered_split_data = []
        for d in self.split_data:
            if self.is_valid_query(d, ingredient_names=ingredient_names):
                filtered_split_data.append(d)
        return Examples("---".join(filtered_split_data))
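
To make the slicing and filtering semantics concrete, here is a small sketch; the prompt strings are hypothetical, but the calls mirror `__getitem__`, `__len__`, and `filter` as defined above:

```python
from blendsql import LLMMap, LLMQA
from blendsql.prompts import Examples

# Hypothetical prompt string holding two examples; "---" is the
# delimiter that __attrs_post_init__ splits on.
prompts = Examples(
    "Q: ... BlendSQL: SELECT {{LLMQA('...')}}\n"
    "---\n"
    "Q: ... BlendSQL: SELECT * FROM w"
)

print(len(prompts))  # 2 -- one entry per "---"-delimited chunk
print(prompts[:1])   # "Examples:" header plus the first chunk

# Keep only examples whose ingredient calls all appear in the given
# set, per is_valid_query above.
filtered = prompts.filter({LLMQA, LLMMap})
```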