Aristotle – The WhatsApp AI Bot in Python

AI Aristotle

Playing around with the ChatGPT and Bard user interfaces has been fun, but I end up using the Meta AI chat a lot more since its interface slides smoothly into my daily workflow. Meta AI is great at this and is super useful for daily tasks and lookups. I also wanted the option of using the OpenAI API, as the GPT models seem to give me the best results, at least for some of my more random queries. An unexpected benefit is being able to use this on WhatsApp when I’m traveling internationally, as Meta AI on WhatsApp currently seems to work only in the US market.

Enter Aristotle – a WhatsApp-based bot that plugs into the OpenAI model of your choosing to give you the result you want. Granted, it’s just as easy to flip to a browser on your phone and get the same info, but I decided to replicate this functionality for kicks, since chat is the form factor of choice for me instead of loading a browser, logging into the interface and so on.

This was also a good opportunity to test out the Assistants API and function calling. I also had an Ubuntu machine lying around at home that I wanted to put to some use, and this is a small, simple use case – all it needs to do is stay awake and make the API calls when triggered. Credit to the OpenAI cookbook for a lot of the boilerplate.

Steps

The basic flow was something like this:

  1. Use the OpenAI playground to set up the Assistant. Create a get_news function with the parameters you can reference in the code
  2. Grab your OpenAI API key and the Assistant ID
  3. Create a Python function to grab the latest news from an API (I used newsapi.org)
  4. Use the Assistants API to write the script that responds to your query and triggers the function when asked for news
  5. Wrap the function in a Flask server that you can run on a server
  6. Set up a Twilio sandbox to obtain a sandbox WhatsApp number and connect to it via your phone
  7. Start chatting!

Assistant

Setting up the playground was pretty easy with something like below.
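The playground screenshot isn’t reproduced here, but the get_news function definition attached to the Assistant looked roughly like the following (the description strings are illustrative; the skeyword parameter matches the fetch_news code below):

{
  "name": "get_news",
  "description": "Fetch the latest news articles for a given search keyword",
  "parameters": {
    "type": "object",
    "properties": {
      "skeyword": {
        "type": "string",
        "description": "Search keyword to look up news articles for"
      }
    },
    "required": ["skeyword"]
  }
}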

Get News

import requests

# Function to fetch news articles from newsapi.org
def fetch_news(skeyword):
    # news_api_key holds the newsapi.org API key
    url = f'https://newsapi.org/v2/everything?q={skeyword}&sortBy=popularity&pageSize=5&apiKey={news_api_key}'
    response = requests.get(url)
    if response.status_code == 200:
        return response.json()
    return "No response from the news API"

This function fetches news articles from the News API. It takes a single argument, `skeyword`, which is the search keyword used to find relevant news articles.

Set up the AI bot

The chat_with_bot function is the core of the chatbot. It starts by getting the user’s message from a form submission. It then creates a new thread with OpenAI and appends the user’s message to the thread to keep the context. The assistant is then run with the instructions that give it its personality. The function then waits until the assistant’s run is no longer queued or in progress. If the run status is “requires_action”, it executes the required functions and submits their outputs back to the assistant. This process is repeated until the run is no longer queued, in progress, or requiring action. After the run completes, it retrieves the messages from the thread, adds them to the conversation history and returns the assistant’s latest reply. The comments in the code should be self-explanatory. I did notice the WhatsApp messages coming back empty if the token count was not limited to whatever length Twilio can handle; I finally settled on 250 tokens as that seemed to work for most of my queries.

def chat_with_bot():
    # Obtain the user's message from the request
    question = request.form.get('Body', '').lower()
    print("user query ", question)
    thread = openaiclient.beta.threads.create()
    try:
       # Insert user message into the thread
        message = openaiclient.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=question
        )
        # Run the assistant
        run = openaiclient.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant_id,
            instructions=instructions
        )
        # Show assistant response
        run = openaiclient.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id
        )
        # Pause until the run is no longer queued or in progress
        count = 0
        while (run.status == "queued" or run.status == "in_progress") and count < 5:
            time.sleep(1)
            run = openaiclient.beta.threads.runs.retrieve(
                thread_id=thread.id,
                run_id=run.id
            )
            count = count + 1
        if run.status == "requires_action":
            # Obtain the tool outputs by running the necessary functions
            aristotle_output = run_functions(run.required_action)
            aristotle_output = run_functions(run.required_action)
            # Submit the outputs to the Assistant
            run = openaiclient.beta.threads.runs.submit_aristotle_output(
                thread_id=thread.id,
                run_id=run.id,
                aristotle_output=aristotle_output
            )
        # Wait until it is not queued
        count = 0
        while (run.status == "queued" or run.status == "in_progress" or run.status == "requires_action") and count < 5:
            time.sleep(2)
            run = openaiclient.beta.threads.runs.retrieve(
                thread_id=thread.id,
                run_id=run.id
            )
            count = count + 1
        # After run is completed, fetch thread messages
        messages = openaiclient.beta.threads.messages.list(
            thread_id=thread.id
        )
        # TODO look for method call in data in messages and execute method
        print(f"---------------------------------------------")
        print(f"THREAD MESSAGES: {messages}")
        # Add each thread message to the conversation history
        for message in messages:  # Loop through the paginated messages
            system_message = {"role": "system",
                              "content": message.content[0].text.value}
            # Append to the conversation history
            conversation_history.append(system_message)
        # The most recent message in the thread is the assistant's reply
        answer = messages.data[0].content[0].text.value
        print(f"---------------------------------------------")
        print(f"ARISTOTLE: {answer}")
        return str(answer)
    except Exception as e:
        print(f"An error occurred: {e}")
        return 'Sorry, I could not process that.'

run_functions

This function is used to execute any required actions that the assistant needs to perform. It loops through the tool calls in the required actions, gets the function name and arguments, and calls the corresponding Python function if it exists. The function’s output is then serialized to JSON and added to the list of tool outputs, which is returned at the end of the function.
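The repo has the full implementation; here is a minimal sketch of what run_functions does, assuming an available_functions mapping from the assistant's function names to their Python implementations (get_news maps to fetch_news here):

import json

# Map of function names the assistant can call to their Python implementations
available_functions = {"get_news": fetch_news}

def run_functions(required_action):
    tool_outputs = []
    # Loop through the tool calls the assistant is asking for
    for tool_call in required_action.submit_tool_outputs.tool_calls:
        function_name = tool_call.function.name
        arguments = json.loads(tool_call.function.arguments)
        if function_name in available_functions:
            result = available_functions[function_name](**arguments)
        else:
            result = f"Unknown function: {function_name}"
        # Serialize the output and tie it back to the originating tool call
        tool_outputs.append({"tool_call_id": tool_call.id, "output": json.dumps(result)})
    return tool_outputs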

The functions are wrapped in a Flask server that runs 24/7 on a machine I have lying around at home. Considering the whopping number of active users (i.e. 1), server load is not really a concern.

TWILIO

Setting up a WhatsApp connection at https://console.twilio.com/ is straightforward – Twilio’s documentation and interface make it super easy.

Once the WhatsApp sandbox is set up (the free credits make it easy to get started), the code below connects the application with Twilio and WhatsApp.

from twilio.rest import Client
from twilio.twiml.messaging_response import MessagingResponse

# client is the Twilio REST client, initialized elsewhere as Client(account_sid, auth_token)
twilio_response = MessagingResponse()
reply = twilio_response.message()
answer = chat_with_bot()

# Add all the numbers to which you want to send a message to this list
numbers = [TWILIO_FROM_NUMBER, TWILIO_TO_NUMBER]
for number in numbers:
    message = client.messages.create(
        body=answer,
        from_='whatsapp:+' + TWILIO_FROM_NUMBER,
        to='whatsapp:+' + number,
    )
    print(message.sid)

So far, this has been useful for switching out models, especially during travel. Some things I’d love to add to this:

  1. Add more custom functions
  2. Speed up the backend by running on a beefier server
  3. Tune the application for faster responses

Code here: https://github.com/vishwanath79/aristotle-whatsapp

Deep Learning generated guitar solo on Hi Phi Nation podcast

One of my older tracks (trained on an LSTM) to generate an 80s shred-inspired guitar solo (mostly done for laughs and quite embarrassing in hindsight 🙂) got featured on the Hi Phi Nation podcast. You can listen to the episode here:

Great episode exploring AI and machine-generated music technologies and AI-generated compositions. My AI-generated solo (around the 45-minute mark) faced off against two guitarists, which was fun and also hilarious given the guitarists’ comments on the AI solo, such as “reminded me of like droids from Star Wars” and “no structure to it”.

AI-driven creativity is an extremely intriguing topic, and my personal research continues on embedding AI models in chord or note sequences to enhance creativity. I think the peak experience of personalized music composition using AI has to be a human/machine collaboration (at least until the singularity occurs and blows our collective understanding of reality away). This is an exciting space, with my latest muse being newly open-sourced libraries such as Audiocraft that are powered by dynamic music generation language models. More experiments incoming.

PS: I had a whole blog post on the mechanics of generating the solo here.

Tenacity for retries

Most times when I have taken over new roles or added more portfolios to my current workstreams, it has usually involved deciding whether to build on or completely overhaul legacy code. I’ve truly encountered some horrific code bases that are no longer understandable, secure or scalable. Often, if there is no ownership of the code or attrition has rendered it an orphan, it sits unnoticed (like an argiope spider waiting for its prey) until it breaks (usually on a Friday afternoon) and someone has to take ownership of it and bid their weekend plans goodbye. Rarely do teams have a singular coding style with patterns and practices clearly defined, repeatable and able to withstand people or technology change – if you are in such a utopian situation, consider yourself truly lucky!

A lot of the time, while you have to understand and re-engineer the existing codebase, it is imperative to keep the lights on while you figure out the right approach – sort of like having to keep the car moving while changing its tires. I discovered Tenacity while wading through a bunch of gobbledygook shell and Python scripts crunched together that ran a critical data pipeline, and I had to figure out quick retry logic to keep the job running while it randomly failed due to environment issues (exceptions, network timeouts, missing files, low memory and every other entropy-inducing situation on the platform). The library can handle different scenarios based on your use case.

  • Continuous retrying of your commands. This usually works if it’s a trivial call where you have minimal concern about overwhelming target systems. Here, the function is retried until no exception is raised.
  • Stop after a certain number of tries. For example, if you have a shell script like shellex below that needs to execute and you anticipate delays or issues with the target URL, or if a command you run raises an exception, you can have Tenacity retry based on your requirement (see the sketch after this list).
  • Exponential backoff patterns – to handle cases where the delays may be too long or too short for a fixed wait time, Tenacity has a wait_exponential algorithm that ensures that if a request can’t succeed shortly after a retry, the application waits longer and longer with each retry, sparing the target system repetitive fixed-interval retries.
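Here is a hedged sketch of those patterns – shellex.sh and the health-check URL below are placeholders, not the actual pipeline scripts:

import subprocess
import requests
from tenacity import retry, stop_after_attempt, wait_fixed, wait_exponential

# Stop after a fixed number of tries: retry the (hypothetical) shellex.sh script
# up to 3 times, waiting 2 seconds between attempts.
@retry(stop=stop_after_attempt(3), wait=wait_fixed(2))
def run_shellex():
    # check=True raises CalledProcessError on a non-zero exit code, which triggers a retry
    subprocess.run(["sh", "shellex.sh"], check=True)

# Exponential backoff: wait roughly 2s, 4s, 8s... capped at 60 seconds between attempts.
@retry(wait=wait_exponential(multiplier=2, min=2, max=60), stop=stop_after_attempt(5))
def call_flaky_endpoint():
    response = requests.get("https://example.com/health", timeout=5)
    response.raise_for_status()  # any HTTP error triggers another, backed-off retry
    return response.json()

if __name__ == "__main__":
    run_shellex()
    call_flaky_endpoint()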

The library handles plenty of other use cases like error handling, custom callbacks and tons more.

Overall, it has been a great find to use a library like Tenacity for various retry use cases instead of writing custom retry logic or standing up a new orchestration engine just to handle retries.

Asynchronicity with asyncio and aiohttp

The usual synchronous versus asynchronous versus concurrent versus parallel discussion is a topic in technical interviews that usually leads to expanded conversations about the candidate’s overall competency on scaling, along with interesting rabbit holes and examples. While keeping the conversations open-ended, I’ve noticed it’s a great sign when candidates incorporate techniques to speed up parallelism, enhance concurrency or mention other ways of speeding up processing.

It’s also important to distinguish between CPU-bound and IO-bound tasks in such situations, since parallelism is effective for CPU-bound tasks (e.g. preprocessing data, running ensembles of models) while concurrency works best for IO-bound tasks (web scraping, database calls). CPU-bound tasks are great for parallelization across multiple CPUs, where a task can be split into multiple subtasks, while IO-bound tasks are not CPU dependent but spend their time reading from or writing to disk or the network.

Key standard libraries in Python for concurrency include:

  • asyncio – for concurrency with coroutines and event loops. This is most similar to a pub-sub model.
  • concurrent.futures – for concurrency via threading

Both are limited by the global interpreter lock (GIL) and run as a single process, multi-threaded.
Note: for parallelism, the library I’ve usually used is multiprocessing, which is a topic for another post.

asyncio is a great tool to execute tasks concurrently and a great way to add asynchronous calls to your program. However, the right use case matters, as it can also lead to unpredictability since tasks can start, run and complete at overlapping times, with the event loop switching context between them. When a task blocks waiting on IO, asyncio moves on to the next available task in the queue and runs it until it completes or blocks. The key here is that there is no manual checking of whether a task has freed up – asyncio resumes it when its awaited result is actually ready.

This post has a couple of quick examples of asyncio using the async/await syntax for event-loop management. Plenty of other libraries are available, but asyncio is usually sufficient for most workloads, and a couple of simple examples go a long way in explaining the concept.

The key calls here are –

  • async – tells the Python interpreter that the function defines a coroutine to be run asynchronously on an event loop
  • await – while waiting for a result to be returned, this passes control back to the event loop, suspending the execution of the current coroutine so the event loop can run other things until the awaited result is ready
  • run – schedule and run a coroutine (Python 3.7+). In earlier versions (3.5/3.6) you can get the event loop and call run_until_complete instead
  • sleep – suspends execution of a task so the event loop can switch to another task
  • gather – execute multiple coroutines that need to finish before resuming the current context, returning the list of responses from each coroutine

Example of using async for multiple tasks:
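The original embed isn’t shown here, but a minimal sketch along those lines, using asyncio.gather over two illustrative tasks, looks like this:

import asyncio

async def task(name, delay):
    print(f"{name} started")
    await asyncio.sleep(delay)  # yields control back to the event loop while "waiting"
    print(f"{name} finished after {delay}s")
    return name

async def main():
    # Both coroutines run concurrently; total time is roughly the longest delay, not the sum
    results = await asyncio.gather(task("task-1", 2), task("task-2", 3))
    print(results)

asyncio.run(main())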

A trivial example like the above goes a long way in explaining the core usage of the asyncio library, especially when bolting it onto long-running Python processes that are primarily slow due to IO.

aiohttp is another great resource for asynchronous HTTP requests, which by nature are a great use case for asynchronicity: while requests wait for servers to respond, other tasks can run. It basically works by creating a client session that can be reused across multiple individual requests and can hold connections to up to 100 different servers at the same time.

Non-async example

A quick example that makes requests to a website (https://api.covid19api.com/total/country/{country}/status/confirmed) that returns a JSON string based on the specific request. The specific API is not important here and is used only to demonstrate the async functionality.
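A sketch of the synchronous version, assuming an illustrative list of countries and simple timing via time.perf_counter:

import time
import requests

countries = ["india", "germany", "brazil", "canada", "mexico"]  # illustrative list

def get_confirmed(country):
    url = f"https://api.covid19api.com/total/country/{country}/status/confirmed"
    response = requests.get(url)
    return response.json()

start = time.perf_counter()
for country in countries:
    data = get_confirmed(country)  # each request blocks until the server responds
    print(country, len(data))      # assumes the endpoint returns a JSON list of records
print(f"Synchronous run took {time.perf_counter() - start:.2f}s")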

Async example using aiohttp, which is needed in order to call the same endpoint asynchronously.
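A sketch of the aiohttp version, reusing one ClientSession across all the requests and gathering the coroutines:

import asyncio
import time
import aiohttp

countries = ["india", "germany", "brazil", "canada", "mexico"]  # illustrative list

async def get_confirmed(session, country):
    url = f"https://api.covid19api.com/total/country/{country}/status/confirmed"
    async with session.get(url) as response:
        return await response.json()

async def main():
    # One client session shared across all requests
    async with aiohttp.ClientSession() as session:
        tasks = [get_confirmed(session, country) for country in countries]
        results = await asyncio.gather(*tasks)  # all requests are in flight concurrently
        for country, data in zip(countries, results):
            print(country, len(data))

start = time.perf_counter()
asyncio.run(main())
print(f"Async run took {time.perf_counter() - start:.2f}s")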

The example clearly shows the time difference, with the async calls roughly halving the time taken for the same requests. Granted, it’s a trivial example, but it shows the benefit of non-blocking async calls and can be applied to any situation that deals with multiple requests to different servers.

Key Points

  • asyncio is usually a great fit for IO-bound problems
  • Putting async before every function will not be beneficial, as blocking calls can slow down the code, so validate the use case first
  • async/await supports a specific set of libraries, so for specific calls (say to databases) the Python wrapper library you use will need to support async/await
  • https://github.com/timofurrer/awesome-asyncio is the go-to place for higher-level async APIs, along with https://docs.python.org/3/library/asyncio-task.html

Merkle Trees for Comparisons – Example

Merkle trees (named after Ralph Merkle, one of the fathers of modern cryptography) are fascinating hash-based data structures used to verify the integrity of data in peer-to-peer systems. Systems like Dynamo use them to compare hashes – a Merkle tree is essentially a binary tree of hashes, typically used to resolve conflicts on reads. For example, in a distributed system, if a replica node falls considerably behind its peers, techniques like vector clocks might take an unacceptably long time to resolve; a hash-based comparison approach like a Merkle tree helps quickly compare two copies of a range of data on different replicas. This is also a core part of blockchains like Ethereum, which uses a non-binary variant, but the binary ones are the most common, easy to understand and fun to implement.

Conceptually this involves:

  1. Compare the root hashes of both trees.
  2. If they differ, continue recursing on the left and right children to locate the leaves (and hence the data ranges) that differ.

The “Merkle root” summarizes all the transactions in a single value.

Simple Example

For example, if TA, TB, TC and TD are transactions (they could be files, keys etc.) and H is a hash function, you can construct a tree by taking the transactions and hashing their concatenated values to generate parents, finally reducing everything to a single root. In my scrawl above, this means hashing TA and TB, then TC and TD, then hashing their concatenations H(AB) and H(CD) to land at H(ABCD). Essentially, keep hashing until all the transactions meet at a single hash.

Example

Here’s an example that uses this technique to compare two files by generating their Merkle roots to validate whether they are equal or not (comments inline).

Invoke the script with python merkle_sample.py "<file1>.csv" "<file2>.csv" to compare the two Merkle trees. Code below:
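The repo has the full merkle_sample.py; a condensed sketch of the same idea is below (the chunk size and the duplicate-last-leaf handling are illustrative choices, not necessarily what the repo does):

import hashlib
import sys

def sha256(data):
    return hashlib.sha256(data).hexdigest()

def merkle_root(leaves):
    # Reduce a list of leaf hashes pairwise until a single root hash remains
    if not leaves:
        return sha256(b"")
    level = list(leaves)
    while len(level) > 1:
        if len(level) % 2 == 1:          # duplicate the last hash if the level is odd
            level.append(level[-1])
        # hash the concatenation of each pair to build the next level up
        level = [sha256((level[i] + level[i + 1]).encode())
                 for i in range(0, len(level), 2)]
    return level[0]

def file_merkle_root(path, chunk_size=1024):
    # each fixed-size chunk of the file becomes a leaf of the tree
    leaves = []
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            leaves.append(sha256(chunk))
    return merkle_root(leaves)

if __name__ == "__main__":
    root1 = file_merkle_root(sys.argv[1])
    root2 = file_merkle_root(sys.argv[2])
    print("Files match" if root1 == root2 else "Files differ")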

The key advantage here is that each branch of the tree can be checked independently without downloading the whole dataset for comparison.

This translates to fewer disk reads for synchronization, though that efficiency needs to be balanced against recalculating the entire tree when nodes leave or go down. This is fundamental to cryptocurrencies, where transactions need to be validated by nodes and there is an enormous time and space cost to validating every transaction – Merkle trees mitigate this in logarithmic rather than linear time. The Merkle root gets put into the block header that is hashed in the process of mining, and comparisons are made via the Merkle root rather than submitting all the transactions over the network. Ethereum uses a more complex variant, namely the Merkle Patricia tree.

The applications of this range beyond blockchains to Torrents, Git, Certificates and more.

Text Summarizer on Hugging Face with mlflow


Hugging Face is the go-to resource for open source natural language processing these days. The Hugging Face hubs are an amazing collection of models, datasets and metrics for getting NLP workflows going. It’s relatively easy to incorporate this into an mlflow paradigm if you are using mlflow for your model management lifecycle. mlflow makes it trivial to track the model lifecycle, including experimentation, reproducibility and deployment. mlflow’s open format makes it my go-to framework for tracking models in an array of personal projects, and it also has an impressive enterprise implementation that my teams at work enable for large enterprise use cases. For smaller projects, it’s great to use mlflow locally for anything that requires model management, as this example illustrates.

The beauty of Hugging Face (HF) is the ability to use their pipelines to run models for inference. The models are products of massive training workflows performed by big tech and are available to ordinary users, who can use them for inference. The HF pipelines offer a simple API dedicated to performing inference with these models, sparing the ordinary user the complexity and compute/storage requirements of training such large models.

The goal was to put some sort of tracking around all my experiments with the Hugging Face summarizer that I’ve been using to summarize text, then use mlflow serving via REST as well as run predictions on the inferred model by passing in a text file. The code repository is here, with snippets below.

Running the Text Summarizer and calling it via curl

Text summarization comes in extractive and abstractive flavors: extractive summarization selects the sentences with the most valuable context, while abstractive summarization is trained to generate new summary sentences.

Since I was running on a CPU, I picked a small model – the T5-small model trained on the Wikihow All dataset, which has been trained to write summaries. The boilerplate code on the Hugging Face website gives you all you need to get started. Note that this model’s input length is capped at 512 tokens, which may not be optimal for use cases with larger text.

a) The first step is to define a wrapper around the model code so it can be called easily later on, by subclassing mlflow.pyfunc.PythonModel to use custom logic and artifacts.

class Summarizer(mlflow.pyfunc.PythonModel):
    '''
    Any MLflow Python model is expected to be loadable as a python_function model.
    '''

    def __init__(self):
        from transformers import pipeline, AutoTokenizer, AutoModelWithLMHead

        self.tokenizer = AutoTokenizer.from_pretrained(
            "deep-learning-analytics/wikihow-t5-small")

        self.summarize = AutoModelWithLMHead.from_pretrained(
            "deep-learning-analytics/wikihow-t5-small")

    def summarize_article(self, row):
        tokenized_text = self.tokenizer.encode(row[0], return_tensors="pt")

        # T5-small model trained on Wikihow All data set.
        # model was trained for 3 epochs using a batch size of 16 and learning rate of 3e-4.
        # Max_input_lngth is set as 512 and max_output_length is 150.
        s = self.summarize.generate(
            tokenized_text,
            max_length=150,
            num_beams=2,
            repetition_penalty=2.5,
            length_penalty=1.0,
            early_stopping=True)

        s = self.tokenizer.decode(s[0], skip_special_tokens=True)
        return [s]

    def predict(self, context, model_input):
        model_input[['name']] = model_input.apply(
            self.summarize_article)

        return model_input

b) We define the tokenizer to prepare the inputs of the model, and the model itself, using the Hugging Face specifications. This is a smaller model trained on the Wikihow All dataset. From the documentation: the model was trained for 3 epochs using a batch size of 16 and a learning rate of 3e-4. Max_input_length is set to 512 and max_output_length to 150.

c) Then we define the model specifics of the T5-small model in the summarize_article function, which is called with the tokenized text for every row in the dataframe input and eventually returns the prediction.

d) The predict function calls summarize_article with the model input and returns the prediction. This is also where we can plug in mlflow to infer the predictions.

The input and output schemas are defined via the ModelSignature class as follows:

# Input and Output formats
input = json.dumps([{'name': 'text', 'type': 'string'}])
output = json.dumps([{'name': 'text', 'type': 'string'}])
# Create the model signature from the schema spec
signature = ModelSignature.from_dict({'inputs': input, 'outputs': output})


e) We can set up mlflow operations by setting the tracking URI, which was “” in this case since it is running locally. It’s trivial in a platform like Azure to spin up a Databricks workspace and get a tracking server automatically, so you can persist all artifacts at cloud scale.
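A minimal sketch of that local setup (the experiment name is illustrative):

import mlflow

# An empty tracking URI points the client at the local ./mlruns directory
mlflow.set_tracking_uri("")
mlflow.set_experiment("hf_summarizer")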


Start tracking the runs by wrapping the work in an mlflow.start_run invocation. The key here is to log the model with the mlflow.pyfunc flavor so the Python code loads into mlflow. In this case, the dependencies of the model are all stored directly with the model. There are plenty of parameters that can be tweaked, described here.

# Start tracking
with mlflow.start_run(run_name="hf_summarizer") as run:
    print(run.info.run_id)
    runner = run.info.run_id
    print("mlflow models serve -m runs:/" +
          run.info.run_id + "/model --no-conda")
    mlflow.pyfunc.log_model('model', loader_module=None, data_path=None, code_path=None,
                            conda_env=None, python_model=Summarizer(),
                            artifacts=None, registered_model_name=None, signature=signature,
                            input_example=None, await_registration_for=0)


f) Check the runs via the mlflow UI, either using the “mlflow ui” command or by invoking the command mlflow models serve -m runs:/<run_id>/model to serve the model.


g) That’s it – call the curl command using the sample text below:

curl -X POST -H "Content-Type:application/json; format=pandas-split" --data '{"columns":["text"],"data":[["Howard Phillips Lovecraft August 20, 1890 – March 15, 1937) was an American writer of weird and horror fiction, who is known for his creation of what became the Cthulhu Mythos.Born in Providence, Rhode Island, Lovecraft spent most of his life in New England. He was born into affluence, but his familys wealth dissipated soon after the death of his grandfather. In 1913, he wrote a critical letter to a pulp magazine that ultimately led to his involvement in pulp fiction.H.P.Lovecraft wrote his best books in Masachusettes."]]}' http://127.0.0.1:5000/invocations

Output:

"name": "Know that Howard Phillips Lovecraft (H.P.Lovecraft was born in New England."}]%

Running the Text Summarizer and calling it via a text file

For larger text, it’s more convenient to read the text from a file, format it and run the summarizer on it. The predict_text.py script does exactly that.


a) Clean up the text in article.txt and load the text into a dictionary.
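Roughly, that cleanup step looks like this (the exact cleanup in the repo may differ; the 'text' key matches the input schema defined earlier):

# Read the article, collapse whitespace/newlines and build the model input
with open("article.txt", "r") as f:
    text = f.read()

cleaned = " ".join(text.split())   # strip line breaks and extra whitespace
dict1 = {"text": cleaned}          # key matches the 'text' column in the signature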

b) Load the model using pyfunc.load_model and then run the model.predict on the dictionary.

# Load model as a PyFuncModel.
# logged_model is the run's artifact URI, e.g. "runs:/<run_id>/model"
loaded_model = mlflow.pyfunc.load_model(logged_model)

# Predict on a Pandas DataFrame built from the dictionary.
summary = loaded_model.predict(pd.DataFrame(dict1, index=[0]))

print(summary['name'][0])

Code here

In summary, this is a useful way to track models and outcomes from readily available transformer pipelines and to pick the best ones for the task.

Spotify Recommender API Call

One of my favorite features in Spotify is the recommendations. The app’s recommendations include the Discover Weekly, Daily Mix, Release Radar and Artist Radio features. I could go through hours of recommendations substituting for white noise while working on projects, and I usually encounter a song or an artist that appeals to my guitar/keyboard-driven sensibilities in a session. While Discover Weekly and Daily Mix yield gems once in a while, the song-specific ones, usually based on Artist/Song Radio, yield a lot more matches for my tastes.

The recommendations endpoint, which generates recommendations based on a seed, is a favorite. I’ve usually had a good match rate with songs that “stick” based on the API. There are plenty of other endpoints (artists, songs etc.) that could easily be plugged in to generate relevant predictions.

The API documentation of Spotify has always been stellar and its usability is enhanced by being able to test all the API calls easily within their developer console.

This API also has a bunch of parameters that can be configured for fine-tuning the recommendation: key, genre, loudness, energy, instrumentalness, popularity, speechiness, danceability etc.

Per the official docs – “Recommendations are generated based on the available information for a given seed entity and matched against similar artists and tracks. If there is sufficient information about the provided seeds, a list of tracks will be returned together with pool size details. For artists and tracks that are very new or obscure there might not be enough data to generate a list of tracks.”

One of the key things here is to generate seeds for the recommendations. This can be done by using endpoints like Get a User’s Top Artists and Tracks to obtain artists and tracks based on my listening history and use them as seeds for the Get Recommendations Based on Seeds endpoint, which only returns tracks.
The Web API Authorization Guide is a must-read before querying these endpoints, and the developer console makes it super easy to try out different endpoints.

I wanted a quick way to query the recommendations API for new recommendations, and the combination of Streamlit + the Spotify API was a quick, simple way to get that working. At a high level, I wanted to be able to query a song or artist and generate recommendations based on it. A secondary need is to collect data for a recommender I am training to customize ML-driven recommendations, but more on that in a different post.

A lot of the code is boilerplate and pretty self-explanatory, but at a high level it consists of a class that interacts with the Spotify API (spotify_api.py) and a UI wrapper using Streamlit to render the app (spotify_explorer.py). Given a client id and client secret, spotify_api.py gets client credentials from the Spotify API to invoke the search. Sample code inline with comments. The code could obviously be much more modular and Pythonic, but for a quick hour of hacking, this got the job done.

class SpotifyAPI(object):
    access_token = None
    access_token_expires = datetime.datetime.now()
    access_token_did_expire = True
    client_id = None
    client_secret = None
    token_url = 'https://accounts.spotify.com/api/token'

    def __init__(self, client_id, client_secret, *args, **kwargs):
        self.client_id = client_id
        self.client_secret = client_secret

    # Given a client id and client secret, gets client credentials from the Spotify API.
    def get_client_credentials(self):
        ''' Returns a base64 encoded string '''
        client_id = self.client_id
        client_secret = self.client_secret
        if client_secret == None or client_id == None:
            raise Exception("check client IDs")
        client_creds = f"{client_id}:{client_secret}"
        client_creds_b64 = base64.b64encode(client_creds.encode())
        return client_creds_b64.decode()

    def get_token_header(self):  # Get header
        client_creds_b64 = self.get_client_credentials()
        return {"Authorization": f"Basic {client_creds_b64}"}

    def get_token_data(self):  # Get token
        return {
            "grant_type": "client_credentials"
        }

    def perform_auth(self):  # perform auth only if access token has expired
        token_url = self.token_url
        token_data = self.get_token_data()
        token_headers = self.get_token_header()

        r = requests.post(token_url, data=token_data, headers=token_headers)

        if r.status_code not in range(200, 299):
            print("Could not authenticate client")
        data = r.json()
        now = datetime.datetime.now()
        access_token = data["access_token"]
        expires_in = data['expires_in']
        expires = now + datetime.timedelta(seconds=expires_in)
        self.access_token = access_token
        self.access_token_expires = expires
        self.access_token_did_expire = expires < now
        return True

    def get_access_token(self):

        token = self.access_token
        expires = self.access_token_expires
        now = datetime.datetime.now()
        if expires < now:
            self.perform_auth()
            return self.get_access_token()
        elif token == None:
            self.perform_auth()
            return self.get_access_token()
        return token

    # search for an artist/track based on a search type provided
    def search(self, query, search_type="artist"):
        access_token = self.get_access_token()
        headers = {"Content-Type": "application/json",
                   "Authorization": f"Bearer { access_token}"}
        # using the  search API at https://developer.spotify.com/documentation/web-api/reference/search/search/
        search_url = "https://api.spotify.com/v1/search?"
        data = {"q": query, "type": search_type.lower()}
        from urllib.parse import urlencode
        search_url_formatted = urlencode(data)
        search_r = requests.get(
            search_url+search_url_formatted, headers=headers)
        if search_r.status_code not in range(200, 299):
            print("Encountered isse=ue")
            return search_r.json()
        return search_r.json()

    def get_meta(self, query, search_type="track"):  # meta data of a track
        resp = self.search(query, search_type)
        all = []
        for i in range(len(resp['tracks']['items'])):
            track_name = resp['tracks']['items'][i]['name']
            track_id = resp['tracks']['items'][i]['id']
            artist_name = resp['tracks']['items'][i]['artists'][0]['name']
            artist_id = resp['tracks']['items'][i]['artists'][0]['id']
            album_name = resp['tracks']['items'][i]['album']['name']
            images = resp['tracks']['items'][i]['album']['images'][0]['url']

            raw = [track_name, track_id, artist_name, artist_id, images]
            all.append(raw)

        return all

  

The get_reccomended_songs function is the core of the app, querying the API for results based on the query passed in. The more parameters, the better the results. Customizing the call for any other API endpoint is fairly trivial.

   def get_reccomended_songs(self, limit=5, seed_artists='', seed_tracks='', market="US",
                              seed_genres="rock", target_danceability=0.1):  # reccomendations API
        access_token = self.get_access_token()
        endpoint_url = "https://api.spotify.com/v1/recommendations?"
        all_recs = []
        self.limit = limit
        self.seed_artists = seed_artists
        self.seed_tracks = seed_tracks
        self.market = market
        self.seed_genres = seed_genres
        self.target_danceability = target_danceability

        # API query plus some additions
        query = f'{endpoint_url}limit={limit}&market={market}&seed_genres={seed_genres}&target_danceability={target_danceability}'
        query += f'&seed_artists={seed_artists}'
        query += f'&seed_tracks={seed_tracks}'
        response = requests.get(query, headers={
                                "Content-type": "application/json", "Authorization": f"Bearer {access_token}"})
        json_response = response.json()

        # print(json_response)
        if response:
            print("Reccomended songs")
            for i, j in enumerate(json_response['tracks']):
                track_name = j['name']
                artist_name = j['artists'][0]['name']
                link = j['artists'][0]['external_urls']['spotify']

                print(f"{i+1}) \"{j['name']}\" by {j['artists'][0]['name']}")
                reccs = [track_name, artist_name, link]
                all_recs.append(reccs)
            return all_recs

Wrapping both calls in a Streamlit app is refreshingly simple, and dockerizing it and pushing to Azure Container Registry was trivial.

Code

https://github.com/vishwanath79/spotifier

Usage

To run the app, run:
streamlit run spotify_explorer.py

Deployed at

https://spotiapp2.azurewebsites.net/

Part 2 to follow at some point as I continue building out a custom recommender that compares the current personalizer with a custom one that takes in audio features, more personalized inputs and tunable parameters.

ONNX for ML Interoperability

Having been a Keras user since I read the seminal Deep Learning with Python, I’ve been experimenting with exporting models to different framework formats to be more framework-agnostic.


ONNX (Open Neural Network Exchange) is an open format for representing traditional and deep learning ML models. Its key goal is promoting interoperability between a variety of frameworks and target environments. ONNX lets you export a fully trained model into its format and enables targeting diverse environments without manual optimization and painful rewrites of the models to accommodate those environments.
It defines an extensible computation graph model along with built-in operators and standard data types to allow for a compact and cross-platform representation for serialization. A typical use case is a transfer learning scenario where you want to reuse the weights of a model possibly built in another framework in your own model: if you build a model in Tensorflow, you get a protobuf (PB) file as output, and it would be great to have one universal format that you can convert to the PT format to load and reuse in Pytorch, or run with a hardware-agnostic runtime.

For high-performance inference requirements in varied frameworks, this is great with platforms like NVIDIA’s TensorRT supporting ONNX with optimizations aimed at the accelerator present on their devices like the Tesla GPUs or the Jetson embedded devices.

Format

The ONNX file is a protobuf-encoded tensor graph. The list of supported operators is documented here, and operations are referred to as “opsets”, i.e. operation sets. Opsets are defined for different runtimes in order to enable interoperability. The operations are a growing list of widely used linear operations, functions and other primitives used to deal with tensors.

The operations include most of the typical deep learning primitives, linear operations, convolutions and activation functions. A model is mapped to the ONNX format by executing it, often with just random input data, and tracing the execution. The operations executed are mapped to ONNX operations, and so the entire model graph is mapped into the ONNX format. The ONNX model is then saved as a .onnx protobuf file, which can be read and executed by a wide and growing range of ONNX runtimes.
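As an illustration of that trace-with-random-input pattern, the PyTorch exporter works exactly this way (the model and input shape below are placeholders):

import torch
import torchvision

# Any trained model works here; resnet18 is just a stand-in
model = torchvision.models.resnet18(pretrained=True).eval()

# A random tensor with the expected input shape drives the trace
dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(model, dummy_input, "resnet18.onnx",
                  input_names=["input"], output_names=["output"],
                  opset_version=11)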

Note – opsets evolve fast, and with the quick release cycles of competing frameworks it may not always be easy to upgrade to the latest ONNX version if it breaks compatibility with other frameworks. The file format consists of the following:

  • Model: Top level construct
    • Associates version Info and Metadata with a graph
  • Graph: describes a function
    • Set of metadata fields
    • List of model parameters
    • List of computation nodes – Each node has zero or more inputs and one or more outputs.
  • Nodes: used for computation
    • Name of node
    • Name of an operator that it invokes a list of named inputs
    • List of named outputs
    • List of attributes

More details here.

Runtime


The ONNX model can be inferenced with ONNX Runtime, which uses a variety of hardware accelerators for optimal performance. The promise of ONNX Runtime is that it abstracts the underlying hardware, enabling developers to use a single set of APIs for multiple deployment targets. Note – ONNX Runtime is a separate project and aims to perform inference for any prediction function converted to the ONNX format.

This has advantages over the dockerized pickled models that are the usual approach in a lot of production deployments, where there are runtime restrictions (i.e. code can run only in .NET or the JVM), memory and storage overhead, version dependencies, and batch prediction requirements.


ONNX Runtime has been integrated into WinML and Azure ML, with Microsoft as its primary backer. Some of the newer enhancements include INT8 quantization of floating point numbers to reduce model size and memory footprint and to increase efficiency, benchmarked here.


The usual path to proceed :

  • Train models with frameworks
  • Convert into ONNX with ONNX converters
  • Use onnx-runtime to verify correctness and Inspect network structure using netron (https://netron.app/)
  • Use hardware-accelerated inference with ONNX runtime ( CPU/GPU/ASIC/FPGAs)

Tensorflow


To convert Tensorflow models, the easiest way is to use the tf2onnx tool from the command line. This converts the saved model to a model representation that includes the inference graph.


Here is an end-to-end example of saving a simple Tensorflow model, converting it to ONNX, then running predictions using the ONNX model and verifying that the predictions match.
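A condensed sketch of what that flow looks like (the model, paths and opset below are illustrative):

import subprocess
import numpy as np
import tensorflow as tf
import onnxruntime as ort

# 1. Train and save a trivial Keras model in the SavedModel format
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu", input_shape=(4,)),
    tf.keras.layers.Dense(1)
])
model.compile(optimizer="adam", loss="mse")
x = np.random.rand(128, 4).astype(np.float32)
y = np.random.rand(128, 1).astype(np.float32)
model.fit(x, y, epochs=2, verbose=0)
model.save("simple_model")

# 2. Convert the SavedModel to ONNX with the tf2onnx command-line tool
subprocess.run(["python", "-m", "tf2onnx.convert",
                "--saved-model", "simple_model",
                "--output", "simple_model.onnx",
                "--opset", "13"], check=True)

# 3. Run the ONNX model with onnxruntime and compare against the Keras predictions
sess = ort.InferenceSession("simple_model.onnx")
input_name = sess.get_inputs()[0].name
onnx_preds = sess.run(None, {input_name: x})[0]
keras_preds = model.predict(x)
print("Predictions match:", np.allclose(onnx_preds, keras_preds, atol=1e-4))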


Challenges


However, one thing to consider while using this format is the lack of “official support” from frameworks like Tensorflow. For example, Pytorch does provide functionality to export models to ONNX (torch.onnx), however I could not find any function to import an ONNX model and output a Pytorch model. Considering that Caffe2, which is part of PyTorch, fully supports ONNX import/export, it may not be totally unreasonable to expect an official conversion importer (there is a proposal already documented here).

The Tensorflow converters are part of the ONNX project, i.e. not an official, out-of-the-box Tensorflow implementation. The list of supported Tensorflow ops is documented here. The GitHub repo is a treasure trove of information on the computation graph model and the operators/data types that power the format. However, as indicated earlier, depending on the complexity of the model (especially in transfer learning scenarios), you are likely to encounter conversion issues during function calls that may cause the ONNX converter to fail. In that case, some scenarios may necessitate modifying the graph in order to fit the format. I’ve had a few issues running into StatefulPartitionedCall ops, especially in transfer learning situations with larger encoders in language models.


I have also had to convert Tensorflow to PyTorch by first converting Tensorflow to ONNX, then converting the ONNX model to Keras using onnx2keras, and then to Pytorch using MMdnn, with mixed results, a lot of debugging and many abandoned attempts. However, I think using the ONNX runtime for inference, rather than framework-to-framework conversions, is a better use of ONNX.



The overall viability of a universal format like ONNX, though well intentioned and highly sought after, may never fully come to fruition with so many divergent interests and priorities among the major contributors, though its need cannot be disputed.

Deep Learned Shred Solo

Music generation with recurrent neural nets has been of great interest to me, with projects like Magenta displaying amazing feats of ML-driven creativity. AI is increasingly being used to augment human creativity, and this trend will help lay creativity blocks to rest in the future. As someone who is usually stuck in a musical rut, this is great for spurring creativity.

With a few covid-induced reconnects with old friends (some of whom are professional musicians) and some inspired late-night MIDI programming in Ableton, I decided to modify some scripts and tutorials that have been lying around on my computer and compose music around them, as I research the most optimal ways to integrate deep learning into original guitar compositions.

There’s plenty of excellent blogs and code on the web on LSTMs, including this one and this one on generating music using Keras, and LSTMs and GRUs for creating music have plenty of boilerplate code on GitHub. For this project, I was going for recording a guitar solo based on artists I like, and setting up a template for future experimentation and research. A few mashed-up 80s guitar solos served as the source data, but the source could have been pretty much anything in MIDI format – it helps to know how to manipulate these files in the DAW, which in my case was Ableton. Most examples on the web use piano MIDI files and generate music in isolation; I wanted to combine the generated music with minimal accompaniment so as to make it “real”.

With the track being trained on in the key of F minor, I also needed to make sure I had some accompaniment in F minor, for which I recorded a canned guitar part with some useful drum programming thanks to EZdrummer.

Tracks in Ableton

Note: this was for research purposes only and for further research into composing pieces that actually make sense based on the key being fed into the model. 

Music21 is invaluable for manipulating MIDI via code. Its utility is that it lets us manipulate starts, durations and pitch. I used Ableton to take the generated MIDI notes and plug in an instrument, along with programmed drums and rhythm guitars.

Step 1:

Find the MIDI file(s) you want to base your ML solo on. In this case, I’m going for generating a guitar solo to layer over a backing track. This could be pretty much anything, as long as it’s MIDI that can be processed by Music21.

Step 2:

Preprocessing the MIDI file(s): the original MIDI file had guitars over drums, bass and keyboards, so the goal was to extract the list of notes first and save them. The instrument.partitionByInstrument() function separates the stream into different parts according to the instrument. If we have multiple files, we can loop over them and partition each by individual instrument. This returns a list of notes and chords in the file.

from glob import glob
from tqdm import tqdm
import music21
from music21 import converter, instrument, note, chord

songs = glob(' /ml/vish/audio_lstm/YJM.mid') # this could be any midi file to be trained on
notes = []
for file in tqdm(songs):
    midi = converter.parse(file) # convert all supported data formats to music21 objects
    parts = None
    try:
        # partition parts for each unique instrument
        parts = instrument.partitionByInstrument(midi)
    except:
        print("No uniques")

    if parts:
        notes_parser = parts.parts[0].recurse()
    else:
        notes_parser = midi.flat.notes # flatten notes to get all the notes in the stream
        print("parts == None")

    for element in notes_parser:
        if isinstance(element, note.Note): # check if the element is a Note
            notes.append(str(element.pitch)) # pitch objects stored as strings
        elif isinstance(element, chord.Chord):
            notes.append('.'.join(str(n) for n in element.normalOrder))

print("notes:", notes)

Step 3:

Creating the model inputs: convert the items in the notes list to integers so they can serve as model inputs. We create arrays for the network input and output to train the model. We have 5741 notes in our input data and have defined a sequence length of 50 notes. The input sequence will be 50 notes, and the output array will store the 51st note for every input sequence that we enter. Then we reshape and normalize the input vector sequence. We also one-hot encode the integers so that the number of columns equals the number of categories, giving a network output shape of (5691, 92). I’ve commented out some of the output so the results are easier to follow.

import numpy as np
from tensorflow.keras.utils import to_categorical

pitch_names = sorted(set(item for item in notes))   # ['0', '0.3.7', '0.4.7', '0.5', '1', '1.4.7', '1.5.8', '1.6', '10', '10.1.5',..]
note_to_int = dict((note, number) for number, note in enumerate(pitch_names))  #{'0': 0,'0.3.7': 1, '0.4.7': 2,'0.5': 3, '1': 4,'1.4.7': 5,..]
sequence_length = 50
len(pitch_names) # 92
range(0, len(notes) - sequence_length, 1) #range(0, 5691)
# Define the input and output sequences
network_input = []
network_output = []
for i in range(0, len(notes) - sequence_length, 1):
    sequence_in = notes[i: i + sequence_length]
    sequence_out = notes[i + sequence_length]
    network_input.append([note_to_int[char] for char in sequence_in]) 
    network_output.append(note_to_int[sequence_out])
print("network_input shape (list):", (len(network_input), len(network_input[0]))) #network_input shape (list): (5691, 50)
print("network_output:", len(network_output)) #network_output: 5691
patterns = len(network_input)  
print("patterns , sequence_length",patterns, sequence_length) #patterns , sequence_length 5691 50
network_input = np.reshape(network_input, (patterns , sequence_length, 1)) # reshape to array of (5691, 50, 1)
print("network input",network_input.shape) #network input (5691, 50, 1)
n_vocab = len(set(notes))
print('unique notes length:', n_vocab) #unique notes length: 92
network_input = network_input / float(n_vocab) 
# one hot encode the output vectors to_categorical(y, num_classes=None)
network_output = to_categorical(network_output)  
network_output.shape #(5691, 92)

Step 4:

Model: we invoke Keras to build out the model architecture using LSTMs. Each input note is used to predict the next note. The code below uses standard model architecture from tutorials without too many tweaks. There are plenty of tutorials online that explain the model way better than I can, such as this one: http://colah.github.io/posts/2015-08-Understanding-LSTMs/

Training on the MIDI input can be expensive and time consuming, so I suggest setting a high epoch number with callbacks defined based on the metrics to monitor. In this case, I used loss, and also created checkpoints for recovery and saved the model as 'weights.musicout.hdf5'. Also note, I trained this on Community Edition Databricks for convenience.

def create_model():
  from tensorflow.keras.models import Sequential
  from tensorflow.keras.layers import Activation, Dense, LSTM, Dropout, Flatten

  model = Sequential()
  model.add(LSTM(128, input_shape=network_input.shape[1:], return_sequences=True))
  model.add(Dropout(0.2))
  model.add(LSTM(128, return_sequences=True))
  model.add(Flatten())
  model.add(Dense(256))
  model.add(Dropout(0.3))
  model.add(Dense(n_vocab))
  model.add(Activation('softmax'))
  model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=["accuracy"])
  model.summary()
  return model

from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
model = create_model()

save_early_callback = EarlyStopping(monitor='loss', min_delta=0,
                                    patience=3, verbose=1,
                                    restore_best_weights=True)
epochs = 5000
filepath = 'weights.musicout.hdf5'
checkpoint = ModelCheckpoint(filepath, monitor='loss', verbose=0, save_best_only=True)
model.fit(network_input, network_output, epochs=epochs, batch_size=32, callbacks=[checkpoint,save_early_callback])

Step 5:

Predict: once we have the model trained, we can start generating notes based on the trained model weights by feeding the model a sequence of notes. We pick a random integer and a random sequence from the input sequence as a starting point. In my case, it involved calling the model.predict function for 1000 notes that can be converted to a MIDI file. The results might vary at this stage; for some reason I saw some degradation after 700 notes, so some tuning is required here.

start = np.random.randint(0, len(network_input)-1)  # randomly pick an integer from input sequence as starting point
print("start:", start)
int_to_note = dict((number, note) for number, note in enumerate(pitch_names))
pattern = network_input[start]
prediction_output = [] # store the generated notes
print("pattern.shape:", pattern.shape)
pattern[:10] # check shape

# generating 1000 notes

for note_index in range(1000):
    prediction_input = np.reshape(pattern, (1, len(pattern), 1))
    prediction_input = prediction_input / float(n_vocab)

    prediction = model.predict(prediction_input, verbose=0) # call the model predict function to predict a vector of probabilities
    
    predict_index = np.argmax(prediction)  # Argmax is finding out the index of the array that results in the largest predict value
    #print("Prediction in progress..", predict_index, prediction)
    result = int_to_note[predict_index]   
    prediction_output.append(result)

    pattern = np.append(pattern, predict_index)
    # Next input to the model
    pattern = pattern[1:1+len(pattern)]

print('Notes generated by model...')
prediction_output[:25] # Out[30]: ['G#5', 'G#5', 'G#5', 'G5', 'G#5', 'G#5', 'G#5',...

Step 6:

Convert to Music21: now that we have our prediction_output numpy array with the predicted notes, it’s time to convert it back into a format that Music21 can recognize, with the objective of writing it back to a MIDI file.

offset = 0
output_notes = []

# create note and chord objects based on the values generated by the model
# convert to Note objects for  music21
for pattern in prediction_output:
    if ('.' in pattern) or pattern.isdigit():  # pattern
        notes_in_chord = pattern.split('.')
        notes = []
        for current_note in notes_in_chord:
            new_note = note.Note(int(current_note))
            new_note.storedInstrument = instrument.Piano() 
            notes.append(new_note)
        new_chord = chord.Chord(notes)
        new_chord.offset = offset
        output_notes.append(new_chord)
    else:  # pattern
        new_note = note.Note(pattern)
        new_note.offset = offset
        new_note.storedInstrument = instrument.Piano()  
        output_notes.append(new_note)

    # increase offset each iteration so that notes do not stack
    offset += 0.5

#Convert to midi
midi_output = music21.stream.Stream(output_notes)
print('Saving Output file as midi....')
midi_output.write('midi', fp=' /ml/vish/audio_lstm/yjmout.midi')

Step 7:

Once we have the MIDI file with the generated notes, the next step is to load the MIDI track into Ableton. From there it is the standard recording process one would follow to record a track in the DAW.

a) Compose and Record the Rhythm guitars, drums and Keyboards.

Instruments/software I used:

(MIDI instrument screenshots)

b) Insert the MIDI track into the DAW, then quantize and sequence accordingly. This can take significant time depending on the precision wanted. In my case, this was just a quick fun project not really destined for the charts, so a quick rough mix and master sufficed.

The track is on SoundCloud here. The solo kicks in around the 16-second mark. Note I did have to adjust the pitch to C to blend in with the rhythm track, though it was originally trained on a track in F minor.

There are other ways of dealing with more sophisticated training, like using different activation functions or normalizing inputs. GRUs are another way to get past this problem, and I plan to iterate on more complex pieces blending deep learning with my compositions. This paper gives a great primer on the difference between LSTMs and GRUs: https://www.scihive.org/paper/1412.355

TFDV for Data validation

Working with my teams to build out a robust feature store these days, it’s becoming even more imperative to ensure feature engineering data quality. The models that gain efficiency from a performant feature store are only as good as the underlying data.

Tensorflow Data Validation (TFDV) is a Python package from the TF Extended (TFX) ecosystem. The package has been around for a while but has now evolved to a point of being extremely useful for machine learning pipelines, as part of feature engineering and for detecting data drift scenarios. Its main functionality is to compute descriptive statistics, infer schema, and detect data anomalies. It’s well integrated with the Google Cloud Platform and Apache Beam. The core API uses Apache Beam transforms to compute statistics over input data.

I end up using it in cases where I need quick checks on data to validate and identify drift scenarios before starting expensive training workflows. This post is a summary of some of my notes on the usage of the package. Code is here.

Data Load

TFDV accepts CSV, Dataframes or TFRecords as input.
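As a quick sketch, the corresponding entry points look like this (the DataFrame and TFRecord paths below are placeholders, not part of the original example):

import pandas as pd
import tensorflow_data_validation as tfdv

# From a CSV file
csv_stats = tfdv.generate_statistics_from_csv(data_location='Data/Musical_instruments_reviews.csv', delimiter=',')

# From a pandas DataFrame
df = pd.read_csv('Data/Musical_instruments_reviews.csv')
df_stats = tfdv.generate_statistics_from_dataframe(df)

# From a TFRecord file of tf.Examples (placeholder path)
tfrecord_stats = tfdv.generate_statistics_from_tfrecord(data_location='Data/reviews.tfrecord')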

The CSV integration and the built-in visualization function make it relatively easy to use within Jupyter notebooks. The library takes the input data and analyzes it feature by feature in order to visualize it. This makes it easy to get a quick understanding of the distribution of values, and helps identify anomalies and training/test/validation skew. It's also a great way to discover bias in the data, since you can see aggregates of values that are skewed towards certain features.

As is evident, with a trivial amount of code you can spot issues immediately: missing columns, inconsistent distributions, and data drift scenarios where a newer dataset has different statistics compared to the data the model was previously trained on.

I used a dataset from Kaggle to quickly illustrate the concept:

import tensorflow_data_validation as tfdv

# Generate statistics from the training CSV
TRAIN = tfdv.generate_statistics_from_csv(data_location='Data/Musical_instruments_reviews.csv', delimiter=',')

# Infer schema from the statistics
schema = tfdv.infer_schema(TRAIN)
tfdv.display_schema(schema)

This generates a data structure that stores summary statistics for each feature.

TFDV Schema

Schema Inference

The schema properties describe every feature present in the 10261 reviews. For example:

  • The type of each feature (e.g. STRING)
  • The uniqueness of features, for example 1429 unique reviewer IDs
  • The expected domains of features
  • The min/max of the number of values for a feature in each example. For example, if A2EZWZ8MBEDOLN is a reviewerID with 36 occurrences, it shows up in top_values:
top_values {
        value: "A2EZWZ8MBEDOLN"
        frequency: 36.0
      }
datasets {
  num_examples: 10261
  features {
    type: STRING
    string_stats {
      common_stats {
        num_non_missing: 10261
        min_num_values: 1
        max_num_values: 1
        avg_num_values: 1.0
        num_values_histogram {
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 1026.1
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 1026.1
          }

Schema inference is usually tedious but becomes a breeze with TFDV. This schema is stored as a protocol buffer.

schema = tfdv.infer_schema(TRAIN)
tfdv.display_schema(schema)

The schema also generates definitions like "Valency" and "Presence". I could not find much detail in the documentation, but I found this useful paper that describes them well.

  • Presence: The expected presence of each feature, in terms of a minimum count and fraction of examples that must contain the feature.
  • Valency: The expected valency of the feature in each example, i.e., minimum and maximum number of values.
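For reference, a rough sketch of how these constraints can be tightened on the inferred schema (the field names follow the schema protocol buffer; 'reviewerName' is simply the feature used elsewhere in this example):

name_feature = tfdv.get_feature(schema, 'reviewerName')

# Presence: require the feature in every example
name_feature.presence.min_fraction = 1.0
name_feature.presence.min_count = 1

# Valency: require exactly one value per example
name_feature.value_count.min = 1
name_feature.value_count.max = 1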

TFDV has inferred reviewerName as STRING, with the universe of values around it termed the Domain. Note that TFDV can also encode your fields as BYTES. I'm not seeing any function call in the API to update the column type in place, but you could easily update it externally if you want to explicitly specify a string. The documentation explicitly advises reviewing the inferred schema and refining it per your requirements, so as to embellish the auto-inference with your own domain knowledge of the data. You can also update a feature's data type to BYTES, INT, FLOAT or STRUCT.

# Convert the 'helpful' feature to BYTES
tfdv.get_feature(schema, 'helpful').type = 1


Once loaded, you can generate the statistics from the CSV file.
For comparison, and to simulate a dataset validation scenario, I cut Musical_instruments_reviews.csv down to 100 rows to compare with the original, and also added an extra feature called 'Internal' with the values A, B, C randomly interspersed for every row.

Visualize Statistics

After this you can use the 'visualize_statistics' call to visualize the two datasets based on the schema of the first dataset (TRAIN in the code). Even though this is limited to two datasets, it is a powerful way to identify issues immediately. For example, it can right off the bat identify "missing features" (such as over 99.6% of values missing in the "reviewerName" feature), as well as split the visualization into numerical and categorical features based on its inference of the data type.

# Load test data to compare
TEST = tfdv.generate_statistics_from_csv(data_location='Data/Musical_instruments_reviews_100.csv', delimiter=',')
# Visualize both datasets
tfdv.visualize_statistics(lhs_statistics=TRAIN, rhs_statistics=TEST, rhs_name="TEST_DATASET",lhs_name="TRAIN_DATASET")


A particularly nice option is the ability to choose a log scale for validating categorical features. The 'Percentages' option can show quartile percentages.

Anomalies

Anomalies can be detected with the validate_statistics call and displayed using display_anomalies. The long and short descriptions allow easy visual inspection of the issues in the data. However, for large-scale validation this may not be enough, and you will need tooling that can handle a stream of defects being presented.

# Display anomalies
anomalies = tfdv.validate_statistics(statistics=TEST, schema=schema)
tfdv.display_anomalies(anomalies)


The various kinds of anomalies that can be detected and their invocation are described here. Some especially useful ones are:

  • SCHEMA_MISSING_COLUMN
  • SCHEMA_NEW_COLUMN
  • SCHEMA_TRAINING_SERVING_SKEW
  • COMPARATOR_CONTROL_DATA_MISSING
  • COMPARATOR_TREATMENT_DATA_MISSING
  • DATASET_HIGH_NUM_EXAMPLES
  • UNKNOWN_TYPE
Anomaly

Schema Updates

Another useful feature is the ability to update the schema and its values to make corrections. For example, to insert a particular value into a feature's domain:

# Insert Values
names = tfdv.get_domain(schema, 'reviewerName').value
names.insert(6, "Vish")  # insert "Vish" at index 6 of the reviewerName domain values

You can also adjust the minimum fraction of values that must come from the domain, flagging the feature only if it falls below a certain threshold.

# Relax the minimum fraction of values that must come from the domain for feature reviewerName
name = tfdv.get_feature(schema, 'reviewerName')
name.distribution_constraints.min_domain_mass = 0.9

Environments

The ability to split data into 'Environments' helps indicate features that are not needed in certain environments, for example if we want the 'Internal' column to be in the TEST data but not the TRAIN data. Features in the schema can be associated with a set of environments using:

  •  default_environment
  •  in_environment
  •  not_in_environment
# Add a default 'TESTING' environment; by default every feature belongs to it
schema2.default_environment.append('TESTING')

# Specify that the 'Internal' feature is not in the TESTING environment
tfdv.get_feature(schema2, 'Internal').not_in_environment.append('TESTING')

# Validate the TEST statistics against the TESTING environment
serving_anomalies_with_env = tfdv.validate_statistics(TEST, schema2, environment='TESTING')

Sample anomaly output:

string_domain {
    name: "Internal"
    value: "A"
    value: "B"
    value: "C"
  }
  default_environment: "TESTING"
}
anomaly_info {
  key: "Internal"
  value {
    description: "New column Internal found in data but not in the environment TESTING in the schema."
    severity: ERROR
    short_description: "Column missing in environment"
    reason {
      type: SCHEMA_NEW_COLUMN
      short_description: "Column missing in environment"
      description: "New column Internal found in data but not in the environment TESTING in the schema."
    }
    path {
      step: "Internal"
    }
  }
}
anomaly_name_format: SERIALIZED_PATH

Skews & Drifts

The ability to detect data skews and drifts is invaluable. However, drift here does not indicate a divergence from the mean; it refers to the "L-infinity" norm of the difference between the summary statistics of the two datasets. We can specify a threshold which, if exceeded for the given feature, flags the drift.

Let's say we have two vectors [2, 3, 4] and [-4, -7, 8]. The L-infinity norm is the maximum absolute value of the element-wise difference between the two vectors, in this case the maximum absolute value of [6, 10, -4], which is 10.
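The same check in numpy, purely as a sanity check of the arithmetic:

import numpy as np

a = np.array([2, 3, 4])
b = np.array([-4, -7, 8])
linf = np.max(np.abs(a - b))  # difference is [6, 10, -4], so the norm is 10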

#Skew comparison
tfdv.get_feature(schema,
                 'helpful').skew_comparator.infinity_norm.threshold = 0.01
skew_anomalies = tfdv.validate_statistics(statistics=TRAIN,
                                          schema=schema,
                                          serving_statistics=TEST)
skew_anomalies

Sample Output:

anomaly_info {
  key: "helpful"
  value {
    description: "The Linfty distance between training and serving is 0.187686 (up to six significant digits), above the threshold 0.01. The feature value with maximum difference is: [0, 0]"
    severity: ERROR
    short_description: "High Linfty distance between training and serving"
    reason {
      type: COMPARATOR_L_INFTY_HIGH
      short_description: "High Linfty distance between training and serving"
      description: "The Linfty distance between training and serving is 0.187686 (up to six significant digits), above the threshold 0.01. The feature value with maximum difference is: [0, 0]"
    }
    path {
      step: "helpful"
    }
  }
}
anomaly_name_format: SERIALIZED_PATH

The drift comparator is useful in cases where the same data is being loaded on a frequent basis and you need to watch for anomalies in order to re-engineer features. The validate_statistics call combined with the drift_comparator threshold can be used to monitor any changes that you need to act on.

# Drift comparator: compare the latest statistics against the previous run
tfdv.get_feature(schema, 'helpful').drift_comparator.infinity_norm.threshold = 0.01
drift_anomalies = tfdv.validate_statistics(statistics=TEST, schema=schema, previous_statistics=TRAIN)
drift_anomalies
Sample output:

anomaly_info {
  key: "reviewerName"
  value {
    description: "The feature was present in fewer examples than expected."
    severity: ERROR
    short_description: "Column dropped"
    reason {
      type: FEATURE_TYPE_LOW_FRACTION_PRESENT
      short_description: "Column dropped"
      description: "The feature was present in fewer examples than expected."
    }
    path {
      step: "reviewerName"
    }
  }
}

You can easily save the updated schema in the format you want for further processing.
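For example, a minimal sketch using TFDV's schema utility functions (the output path is a placeholder):

# Persist the (possibly updated) schema as a text-format protobuf
tfdv.write_schema_text(schema, 'schema.pbtxt')

# Reload it later to validate fresh statistics against the same schema
reloaded_schema = tfdv.load_schema_text('schema.pbtxt')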

Overall, this has been most useful to me for models within the TensorFlow ecosystem, and the documentation indicates that using components like StatisticsGen with TFX makes this a breeze to use in pipelines, with out-of-the-box integration on a platform like GCP.

Using TFDV to identify anomalies for feature drift and inference decay, and thereby avoid time-consuming preprocessing/training steps, is a no-brainer; however, defect handling is up to the developer to incorporate. It's also important to consider that one's domain knowledge of the data plays a huge role in these scenarios, so an auto-fix of all data anomalies may not really work in cases where a careful review is unavoidable.

This can also be extended to general data quality by applying it to any validation case where you are constantly getting updated data for the features. TFDV could even be applied post-training, to any data input/output scenario, to ensure that values are as expected.
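A minimal sketch of what such a post-training check could look like, assuming the incoming batch is available as a pandas DataFrame (the helper function below is hypothetical, not part of TFDV):

import pandas as pd
import tensorflow_data_validation as tfdv

def batch_conforms_to_schema(batch_df: pd.DataFrame, schema) -> bool:
    """Validate a batch of serving/inference data against the training schema."""
    stats = tfdv.generate_statistics_from_dataframe(batch_df)
    anomalies = tfdv.validate_statistics(statistics=stats, schema=schema)
    # anomaly_info is empty when the batch conforms to the schema
    return len(anomalies.anomaly_info) == 0

If this returns False, the anomalies can be inspected with tfdv.display_anomalies before deciding whether to fix the data or retrain.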


Official documentation is here.