Skip to main content

Quickstart

在本指南中,我们将讨论创建基本问答链和代理程序在 SQL 数据库上的方法。这些系统将允许我们询问有关 SQL 数据库中数据的问题,并获得自然语言回答。两者之间的主要区别在于我们的代理程序可以循环查询数据库,直到回答问题所需的次数。

⚠️ Security note ⚠️

构建 SQL 数据库问答系统需要执行模型生成的 SQL 查询。这样做存在固有风险,请确保你的数据库连接权限尽可能狭隘地限制在你的链/代理需求范围内。这将减轻建立模型驱动系统的风险,虽然无法完全消除。有关常规安全最佳实践,请参阅此处。

Architecture

在高层次上,任何SQL链和代理的步骤是:

  1. 将问题转换为 SQL 查询:Model 将用户输入转换为 SQL 查询。
  2. 执行 SQL 查询。
  3. 回答问题:模型使用查询结果响应用户输入。

sql_usecase.png

Setup

首先,获取所需的软件包并设置环境变量:

%pip install --upgrade --quiet  langchain langchain-community langchain-openai

我们将在本指南中使用一个OpenAI模型。

import getpass
import os

os.environ["OPENAI_API_KEY"] = getpass.getpass()

# Uncomment the below to use LangSmith. Not required.
# os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()
# os.environ["LANGCHAIN_TRACING_V2"] = "true"

以下示例将使用与 Chinook 数据库的 SQLite 连接。请按照以下安装步骤在与此笔记本相同的目录中创建 Chinook.db。

  • 保存这个文件为 **Chinook_Sqlite.sql**
  • 运行 sqlite3 Chinook.db
  • 运行 读取 Chinook_Sqlite.sql
  • 测试 选择 * 从艺术家 限制 10;

现在,Chinhook.db位于我们的目录中,我们可以使用SQLAlchemy驱动的SQLDatabase类与之进行接口操作。

from langchain_community.utilities import SQLDatabase

db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")
sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

很好!我们有一个可以查询的SQL数据库。现在让我们尝试将其连接到LLM。

Chain

让我们创建一个简单的链,将一个问题转化为一个SQL查询,执行这个查询,然后利用结果来回答最初的问题。

Convert question to SQL query

SQL 链或代理的第一步是接受用户输入并将其转换为 SQL 查询。LangChain 带有一个内置的链用于这个目的:create_sql_query_chain。

from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
response
'SELECT COUNT(*) FROM Employee'

我们可以执行查询以确保其有效:

db.run(response)
'[(8,)]'

我们可以查看 LangSmith 追踪来更好地了解这个链在做什么。我们也可以直接检查链中的提示。查看提示(如下),我们可以看到它是:

  • 方言特定。在这种情况下,它明确指的是SQLite。
  • 所有可用表格的定义。
  • 每个表格有三个示例行。

这种技术受到类似这篇论文的启发,该论文建议展示示例行并对表格进行明确处理可以提高性能。我们还可以像这样检查完整提示:

chain.get_prompts()[0].pretty_print()
You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today".

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {input}

Execute SQL query

现在我们已经生成了一个 SQL 查询,我们想要执行它。 这是创建 SQL 链的最危险的部分。 仔细考虑是否可以在数据上运行自动化查询。尽量将数据库连接权限最小化。考虑在查询执行之前为您的链添加一个人工审批步骤(见下文)。

我们可以使用 QuerySQLDatabaseTool 轻松地向我们的链添加查询执行:

from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool

execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db)
chain = write_query | execute_query
chain.invoke({"question": "How many employees are there"})
'[(8,)]'

Answer the question

现在我们已经有了自动生成和执行查询的方法,我们只需要将原始问题和SQL查询结果合并起来生成最终答案。我们可以再次将问题和结果传递给LLM。

from operator import itemgetter

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough

answer_prompt = PromptTemplate.from_template(
"""Given the following user question, corresponding SQL query, and SQL result, answer the user question.

Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)

answer = answer_prompt | llm | StrOutputParser()
chain = (
RunnablePassthrough.assign(query=write_query).assign(
result=itemgetter("query") | execute_query
)
| answer
)

chain.invoke({"question": "How many employees are there"})
'There are 8 employees.'

Next steps

对于更复杂的查询生成,我们可能希望创建几个示范提示或添加查询检查步骤。要了解更多高级技巧,请查看:

  • 引导策略:高级提示工程技术。
  • 查询检查 :添加查询验证和错误处理。
  • 大型数据库:与大型数据库合作的技术。

Agents

LangChain有一个SQL代理,可以提供更灵活的方式与SQL数据库进行交互。使用SQL代理的主要优点有:

  • 它可以根据数据库模式和数据库内容回答问题(例如描述特定表)。
  • 它可以通过运行生成的查询,捕捉回溯并正确地重新生成来从错误中恢复。
  • 它可以回答需要多个相关查询的问题。
  • 它将通过仅考虑相关表中的架构来节省令牌。

为了初始化代理,我们使用 create_sql_agent 函数。这个代理包含 SQLDatabaseToolkit,其中包含以下工具:

  • 创建并执行查询
  • 检查查询语法
  • 检索表描述
  • … 以及更多

Initializing agent

from langchain_community.agent_toolkits import create_sql_agent

agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
agent_executor.invoke(
{
"input": "List the total sales per country. Which country's customers spent the most?"
}
)


> Entering new AgentExecutor chain...

Invoking: `sql_db_list_tables` with `{}`


Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `Invoice,Customer`



CREATE TABLE "Customer" (
"CustomerId" INTEGER NOT NULL,
"FirstName" NVARCHAR(40) NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"Company" NVARCHAR(80),
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60) NOT NULL,
"SupportRepId" INTEGER,
PRIMARY KEY ("CustomerId"),
FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)

/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/


CREATE TABLE "Invoice" (
"InvoiceId" INTEGER NOT NULL,
"CustomerId" INTEGER NOT NULL,
"InvoiceDate" DATETIME NOT NULL,
"BillingAddress" NVARCHAR(70),
"BillingCity" NVARCHAR(40),
"BillingState" NVARCHAR(40),
"BillingCountry" NVARCHAR(40),
"BillingPostalCode" NVARCHAR(10),
"Total" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("InvoiceId"),
FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)

/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2009-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2009-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2009-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/
Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 10;`
responded: To list the total sales per country, I can query the "Invoice" and "Customer" tables. I will join these tables on the "CustomerId" column and group the results by the "BillingCountry" column. Then, I will calculate the sum of the "Total" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.

Here is the SQL query:

```sql
SELECT c.Country, SUM(i.Total) AS TotalSales
FROM Invoice i
JOIN Customer c ON i.CustomerId = c.CustomerId
GROUP BY c.Country
ORDER BY TotalSales DESC
LIMIT 10;
```

Now, I will execute this query to get the total sales per country.

[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]The total sales per country are as follows:

1. USA: $523.06
2. Canada: $303.96
3. France: $195.10
4. Brazil: $190.10
5. Germany: $156.48
6. United Kingdom: $112.86
7. Czech Republic: $90.24
8. Portugal: $77.24
9. India: $75.26
10. Chile: $46.62

To answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.

> Finished chain.
{'input': "List the total sales per country. Which country's customers spent the most?",
'output': 'The total sales per country are as follows:\n\n1. USA: $523.06\n2. Canada: $303.96\n3. France: $195.10\n4. Brazil: $190.10\n5. Germany: $156.48\n6. United Kingdom: $112.86\n7. Czech Republic: $90.24\n8. Portugal: $77.24\n9. India: $75.26\n10. Chile: $46.62\n\nTo answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.'}
agent_executor.invoke({"input": "Describe the playlisttrack table"})


> Entering new AgentExecutor chain...

Invoking: `sql_db_list_tables` with `{}`


Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `PlaylistTrack`



CREATE TABLE "PlaylistTrack" (
"PlaylistId" INTEGER NOT NULL,
"TrackId" INTEGER NOT NULL,
PRIMARY KEY ("PlaylistId", "TrackId"),
FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)

/*
3 rows from PlaylistTrack table:
PlaylistId TrackId
1 3402
1 3389
1 3390
*/The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks.

Here is the schema of the `PlaylistTrack` table:

```
CREATE TABLE "PlaylistTrack" (
"PlaylistId" INTEGER NOT NULL,
"TrackId" INTEGER NOT NULL,
PRIMARY KEY ("PlaylistId", "TrackId"),
FOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"),
FOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")
)
```

The `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.

Here are three sample rows from the `PlaylistTrack` table:

```
PlaylistId TrackId
1 3402
1 3389
1 3390
```

Please let me know if there is anything else I can help with.

> Finished chain.
{'input': 'Describe the playlisttrack table',
'output': 'The `PlaylistTrack` table has two columns: `PlaylistId` and `TrackId`. It is a junction table that represents the many-to-many relationship between playlists and tracks. \n\nHere is the schema of the `PlaylistTrack` table:\n\n```\nCREATE TABLE "PlaylistTrack" (\n\t"PlaylistId" INTEGER NOT NULL, \n\t"TrackId" INTEGER NOT NULL, \n\tPRIMARY KEY ("PlaylistId", "TrackId"), \n\tFOREIGN KEY("TrackId") REFERENCES "Track" ("TrackId"), \n\tFOREIGN KEY("PlaylistId") REFERENCES "Playlist" ("PlaylistId")\n)\n```\n\nThe `PlaylistId` column is a foreign key referencing the `PlaylistId` column in the `Playlist` table. The `TrackId` column is a foreign key referencing the `TrackId` column in the `Track` table.\n\nHere are three sample rows from the `PlaylistTrack` table:\n\n```\nPlaylistId TrackId\n1 3402\n1 3389\n1 3390\n```\n\nPlease let me know if there is anything else I can help with.'}

Next steps

前往代理页面了解如何使用和自定义代理。


Help us out by providing feedback on this documentation page: