This notebook illustrates how to integrate Large Language Models (LLMs) with the MITRE ATT&CK framework, letting analysts dynamically generate context-rich answers for threat intelligence and supporting cybersecurity research and analysis.
Original notebook can be found here: https://otrf.github.io/GPT-Security-Adventures/experiments/ATTCK-GPT/notebook.html#generate-knowledge-base-embeddings
NB: The ATT&CK group knowledge base and its associated markdown files were pre-generated using attackcti, courtesy of @cyb3rward0g.
Here we use LangChain's modular document loaders to load the ATT&CK Markdown files, setting the stage for the data-driven and interactive tasks that follow.
# Import
import os
# Define local variables
current_directory = os.getcwd()  # __file__ is not defined in a notebook, so use the current working directory
knowledge_directory = os.path.join(current_directory, "knowledge")
db_directory = os.path.join(current_directory, "db")
templates_directory = os.path.join(current_directory, "templates")
group_template = os.path.join(templates_directory, "group.md")
import glob
from langchain.document_loaders import UnstructuredMarkdownLoader
# Using glob to find all Markdown files in the knowledge_directory
# The "*.md" means it will look for all files ending with .md (Markdown files)
group_files = glob.glob(os.path.join(knowledge_directory, "*.md"))
# Initializing an empty list to store the content of Markdown files
md_docs = []
# Start of the Markdown file loading process
print("[+] Loading Group markdown files..")
# Loop through each Markdown file path in group_files
for group in group_files:
    # print(f' [*] Loading {os.path.basename(group)}')
    # Create an instance of UnstructuredMarkdownLoader to load the content of the current Markdown file
    loader = UnstructuredMarkdownLoader(group)
    # Load the content and extend the md_docs list with it
    md_docs.extend(loader.load())
# Print the total number of Markdown documents processed
print(f'[+] Number of .md documents processed: {len(md_docs)}')
# Display the content of one of the loaded documents
print(md_docs[5].page_content)
Tokenization is the process of converting a sequence of text into individual units, known as "tokens." These tokens can be as small as characters or as long as words, depending on the specific requirements of the task and the language of the text. Tokenization is a crucial pre-processing step in Natural Language Processing (NLP) and text analytics applications, including work with large language models.
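For a quick, concrete illustration, the short sketch below tokenizes a single arbitrary sentence with tiktoken before we apply the same tokenizer to the ATT&CK documents; the example sentence and the resulting token ids are illustrative only.
# Toy example: tokenize one short sentence and inspect the tokens
import tiktoken
toy_tokenizer = tiktoken.encoding_for_model('gpt-4')
toy_tokens = toy_tokenizer.encode("Phishing is a common initial access technique.")
print(toy_tokens)                                                        # integer token ids
print([toy_tokenizer.decode_single_token_bytes(t) for t in toy_tokens])  # the byte piece behind each id
print(f"{len(toy_tokens)} tokens")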
# Import the tiktoken library
import tiktoken
# Initialize the tokenizer for the GPT-4 model
# The function encoding_for_model returns a tokenizer configured for the specified model ('gpt-4' in this case)
tokenizer = tiktoken.encoding_for_model('gpt-4')
# Tokenize the content of the first Markdown document in the md_docs list
# The encode method converts the text into a list of integers, each representing a token
# disallowed_special=() disables the check that would raise an error if special-token text (e.g. '<|endoftext|>') appears in the input, so such text is encoded as ordinary text
token_integers = tokenizer.encode(md_docs[0].page_content, disallowed_special=())
# Count the number of tokens generated
# This is useful for understanding the size of the text and for cost estimation if using OpenAI's API
num_tokens = len(token_integers)
# Decode the integer tokens back to bytes
# This is done using the decode_single_token_bytes method
# This step is optional and is generally used for debugging or analysis
token_bytes = [tokenizer.decode_single_token_bytes(token) for token in token_integers]
# Print the results
# Display the total number of tokens, the integer representation of tokens, and their byte representation
print()
print(f"token count: {num_tokens} tokens")
print(f"token integers: {token_integers}")
print(f"token bytes: {token_bytes}")
# Define a function called tiktoken_len to calculate the number of tokens in a given text
def tiktoken_len(text):
    # Use the tokenizer's encode method to tokenize the input text
    # disallowed_special=() disables the error normally raised when special-token text appears in the input
    tokens = tokenizer.encode(
        text,
        disallowed_special=()  # Disable this check for all special tokens
    )
    # Return the number of tokens generated
    return len(tokens)
# Create a list called token_counts to store the number of tokens for each Markdown document in md_docs
# The tiktoken_len function is used to calculate the token count for each document's content
token_counts = [tiktoken_len(doc.page_content) for doc in md_docs]
# Print the statistics related to token counts
# Calculate and display the minimum, average, and maximum number of tokens across all Markdown documents
print(f"""[+] Token Counts:
Min: {min(token_counts)} # Minimum number of tokens across all documents
Avg: {int(sum(token_counts) / len(token_counts))} # Average number of tokens across all documents
Max: {max(token_counts)} # Maximum number of tokens across all documents
""")
The goal of the "Recursively split by character" method is to split a text into smaller chunks based on a list of characters. The method tries to split the text on these characters in order until the resulting chunks are small enough. The default list of characters used for splitting is ["\n\n", "\n", " ", ""]. This method aims to keep paragraphs, sentences, and words together as much as possible, as these are typically semantically related pieces of text. By default the chunk size is measured in characters; in this notebook we pass the tiktoken_len function defined above as the length function, so chunk size is measured in tokens instead.
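As a minimal sketch of this behaviour (the sample text and the tiny chunk size are made up for illustration, and the default character-based length is used), the following splits a short two-paragraph string so you can see the separators being tried in order:
# Toy example: split a short two-paragraph string into small chunks
from langchain.text_splitter import RecursiveCharacterTextSplitter
toy_text = "APT groups often use spearphishing.\n\nThey also abuse valid accounts to maintain persistence."
toy_splitter = RecursiveCharacterTextSplitter(
    chunk_size=50,       # deliberately tiny; measured in characters here (the default length function)
    chunk_overlap=10,
    separators=['\n\n', '\n', ' ', '']
)
for i, chunk in enumerate(toy_splitter.split_text(toy_text)):
    print(i, repr(chunk))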
# Import the RecursiveCharacterTextSplitter class from the langchain library
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Print a message indicating the initialization of RecursiveCharacterTextSplitter
print('[+] Initializing RecursiveCharacterTextSplitter..')
# Create an instance of RecursiveCharacterTextSplitter with specified parameters
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,                      # Maximum number of tokens in each chunk
    chunk_overlap=50,                    # Number of tokens that overlap between adjacent chunks
    length_function=tiktoken_len,        # Function used to measure chunk length (in tokens)
    separators=['\n\n', '\n', ' ', '']   # Separators tried, in order, to split the text
)
print('[+] Splitting documents in chunks..')
chunks = text_splitter.split_documents(md_docs)
print(f'[+] Number of documents: {len(md_docs)}')
print(f'[+] Number of chunks: {len(chunks)}')
print(chunks[1])
Embedding
What it is: Embedding is a way to convert words or phrases into numbers (vectors) so that a computer can understand and work with them.
Why it's useful: Once text is converted into numbers, it is easier to see how similar different words or sentences are, and to perform tasks like searching and classification.
FAISS (Facebook AI Similarity Search)
What it is: FAISS is a library developed by Facebook that quickly finds the items most similar to a given item, based on their numerical (vector) representations.
Why it's useful: Imagine you have a huge library of books and you want to find the ones most similar to a particular book. FAISS lets you do this very quickly, even if the library is enormous.
Vectors
What they are: A vector is just a list of numbers. In the context of embeddings and FAISS, each number in the vector represents some feature or characteristic of the text.
Why they're useful: Vectors make it easy for computers to compare things. For example, the vector for the word "apple" should be closer to the vector for "fruit" than to the vector for "car," which tells the computer that apples are more related to fruits than to cars.
So, in summary:
Embedding turns text into vectors, vectors are lists of numbers that computers can easily work with, and FAISS uses these vectors to quickly find similar items in a large dataset. The short sketch after this summary illustrates the idea on toy vectors before we embed the actual ATT&CK chunks.
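To make this concrete, here is a minimal sketch using random vectors in place of real text embeddings (and assuming the faiss package that backs LangChain's FAISS wrapper is installed), showing how a FAISS index finds the nearest neighbours of a query vector:
# Toy example: nearest-neighbour search over random vectors with FAISS
import numpy as np
import faiss
d = 8                                                    # dimensionality of each toy vector
np.random.seed(0)
vectors = np.random.random((100, d)).astype('float32')   # stand-ins for text embeddings
query = np.random.random((1, d)).astype('float32')       # stand-in for an embedded query
index = faiss.IndexFlatL2(d)                             # exact L2 (Euclidean) search
index.add(vectors)                                       # index the "document" vectors
distances, indices = index.search(query, 3)              # the 3 closest vectors to the query
print(indices, distances)
In the real pipeline below, OpenAIEmbeddings produces the vectors for each chunk and LangChain manages the FAISS index for us.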
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import FAISS
import openai
import os
# Get your key: https://platform.openai.com/account/api-keys
openai.api_key = os.getenv("OPENAI_API_KEY")
print("[+] Starting embedding..")
embeddings = OpenAIEmbeddings()
# Send text chunks to OpenAI Embeddings API
print("[+] Sending chunks to OpenAI Embeddings API..")
db = FAISS.from_documents(chunks, embeddings)
retriever = db.as_retriever(search_kwargs={"k":5})
query = "What are some phishing techniques used by threat actors?"
print("[+] Getting relevant documents for query..")
relevant_docs = retriever.get_relevant_documents(query)
relevant_docs
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff")
chain.run(input_documents=relevant_docs, question=query)
import ipywidgets as widgets
from ipywidgets import interact_manual, Layout
text_layout = Layout(
    width='80%',    # Set the width to 80% of the container
    height='50px',  # Set the height
)
retriever = db.as_retriever(search_kwargs={"k":3})
def execute_query(query):
    print(f"Your query: {query}")
    print("[+] Getting relevant documents for query..")
    relevant_docs = retriever.get_relevant_documents(query)
    from langchain.chains.question_answering import load_qa_chain
    from langchain.llms import OpenAI
    chain = load_qa_chain(OpenAI(temperature=0), chain_type="stuff")
    result = chain.run(input_documents=relevant_docs, question=query)
    print(result)
interact_manual(execute_query, query=widgets.Text(value='', placeholder='Type your query here', description='Query:', layout=text_layout));
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI  # needed for the ChatOpenAI model below
from langchain.memory import ConversationBufferMemory
from langchain.llms import OpenAI
from langchain.prompts.prompt import PromptTemplate
import json
# Initialize your Langchain model
model = ChatOpenAI(model_name="gpt-4", temperature=0.3)
# Initialize your retriever (assuming the FAISS vector store created earlier is available as 'db')
retriever = db.as_retriever(search_kwargs={"k": 8})
# Define your custom template
custom_template = """You are an AI assistant specialized in MITRE ATT&CK and you interact with a threat analyst, answer the follow up question. If you do not know the answer reply with 'I am sorry'.
Chat History:
{chat_history}
Follow Up Input: {question}
Answer: """
CUSTOM_QUESTION_PROMPT = PromptTemplate.from_template(custom_template)
# Initialize memory for chat history
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
# Initialize the ConversationalRetrievalChain
qa_chain = ConversationalRetrievalChain.from_llm(model, retriever, condense_question_prompt=CUSTOM_QUESTION_PROMPT, memory=memory)
def execute_conversation(question):
    # Load conversational history from file
    try:
        with open('conversational_history.json', 'r') as f:
            conversational_history = json.load(f)
    except FileNotFoundError:
        conversational_history = []
    # Update conversational history with the user's question
    conversational_history.append(("User", question))
    # Use the ConversationalRetrievalChain to get the answer
    result = qa_chain({"question": question})
    # Extract the 'answer' part from the result
    response_text = result.get('answer', 'Sorry, I could not generate a response.')
    # Update conversational history with the bot's response
    conversational_history.append(("Bot", response_text))
    # Limit the history to the last 10 turns
    if len(conversational_history) > 10:
        conversational_history = conversational_history[-10:]
    # Save conversational history to file
    with open('conversational_history.json', 'w') as f:
        json.dump(conversational_history, f)
    # Print only the last message in the conversational history
    last_message = conversational_history[-1]
    print(f"Discussion:\n{last_message[0]}: {last_message[1]}")
# Call the function with a question
execute_conversation("Who is Lazarus?")
execute_conversation("List all the techniques used by this group")
execute_conversation("Tell me more about the third point you mentionned")