Building a board game using JAX, TFLite and Flutter

Wayne Wei
5 min read · Apr 16, 2021


Let’s see how we can use JAX, TFLite and Flutter, 3 cool Google open source products in a single project 😃

Reinforcement learning has been widely used by game companies to build game bots, whether for live gameplay or for offline work such as quality assurance, testing game balance, and designing game characters.

So today we are going to build a simple board game app to illustrate how this could be done. The finished app looks like this:

The code has been open sourced on GitHub. For an introduction to the game rules, please refer to that repo (in short, it's very similar to the Battleship game). In this post, we are going to focus on how to build such a game.

Before we discuss the actual implementations, let me say a few words about the tools we are using.

  • JAX. JAX is a machine learning framework developed by the Google Research team. It is 'NumPy on steroids', as some people say. We are using Flax, one of the high-level neural network libraries built on top of JAX. Technically we could use TensorFlow to do this, but I wanted to learn more about the JAX ecosystem.
  • TFLite. It's the de facto mobile inference framework, widely used in the industry.
  • Flutter. It is Google's cross-platform frontend framework, which is gaining huge traction these days.

Since you are playing against the agent, we need to train an agent that can find all the 'plane cells' quickly. You could probably write manual rules to do this, but that would not be fun. Instead, we are going to train a neural network using a policy gradient reinforcement learning algorithm (REINFORCE).

To implement REINFORCE, we first randomly initialize a simple neural network with 3 dense layers. With Flax, defining the neural network is easy:

import flax.linen as nn
import jax.numpy as jnp

class PolicyGradient(nn.Module):
  """Policy network that maps a board state to a strike-position distribution."""

  @nn.compact
  def __call__(self, x):
    dtype = jnp.float32
    # Flatten the board into a single feature vector
    x = x.reshape((x.shape[0], -1))
    x = nn.Dense(features=2 * BOARD_SIZE**2, name='hidden1', dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Dense(features=BOARD_SIZE**2, name='hidden2', dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Dense(features=BOARD_SIZE**2, name='logits', dtype=dtype)(x)
    # One probability per board cell
    policy_probabilities = nn.softmax(x)
    return policy_probabilities

Then we use that neural network to play one round of the game until the end (a "rollout"). It doesn't matter how good the neural network is initially; as long as we log the board positions, the strike positions and the hit/miss signals, we can feed them all into the training step, compute the gradients, and make the neural network learn what good and bad strikes are. We also do some reward shaping to make learning more effective.
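As a concrete illustration, here is a minimal rollout sketch; helper names such as place_plane, PLANE_CELL_COUNT, HIT_REWARD and MISS_REWARD are illustrative placeholders, not the repo's actual code (and for simplicity it doesn't prevent repeated strikes):

import jax
import jax.numpy as jnp
import numpy as np

def play_one_game(params, rng):
  """Plays one game with the current policy and logs (board, action, reward)."""
  board_pos_log, action_log, reward_log = [], [], []
  board = np.zeros((BOARD_SIZE, BOARD_SIZE), dtype=np.float32)  # 0 = unknown cell
  plane = place_plane(BOARD_SIZE)  # hypothetical helper: hidden plane cells (0/1 grid)
  hits = 0
  while hits < PLANE_CELL_COUNT:
    # Sample a strike position from the policy's probability distribution
    probs = PolicyGradient().apply({'params': params}, board[None, ...])[0]
    rng, key = jax.random.split(rng)
    action = int(jax.random.choice(key, BOARD_SIZE**2, p=probs))
    hit = bool(plane.flatten()[action])
    # Log the state, the action, and the (unshaped) reward
    board_pos_log.append(board.copy())
    action_log.append(action)
    reward_log.append(HIT_REWARD if hit else MISS_REWARD)
    # Mark the cell on the board so the policy sees the result next turn
    board[action // BOARD_SIZE, action % BOARD_SIZE] = 1.0 if hit else -1.0
    hits += int(hit)
  return board_pos_log, action_log, reward_log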

The key equation for REINFORCE is:
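Reconstructed in LaTeX from the linked page, with \(\pi_\theta\) the policy and \(A^{\pi_\theta}\) an advantage estimate:

\nabla_\theta J(\pi_\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[ \sum_{t=0}^{T} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, A^{\pi_\theta}(s_t, a_t) \right]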

Source: https://spinningup.openai.com/en/latest/algorithms/vpg.html

This translates to the following code with JAX/Flax:

def compute_loss(logits, labels, rewards):
  # Note: `logits` here are actually the softmax probabilities returned by the model
  one_hot_labels = jax.nn.one_hot(labels, num_classes=BOARD_SIZE**2)
  # Log-probability of each chosen action, weighted by its (shaped) reward
  loss = -jnp.mean(
      jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) * jnp.asarray(rewards))
  return loss

@jax.jit
def train_iteration(optimizer, board_pos_log, action_log, reward_log):
  def loss_fn(params):
    logits = PolicyGradient().apply({'params': params}, board_pos_log)
    loss = compute_loss(logits, action_log, reward_log)
    return loss
  grad_fn = jax.grad(loss_fn)
  grads = grad_fn(optimizer.target)
  optimizer = optimizer.apply_gradient(grads)
  return optimizer

Note the "jnp.asarray(rewards)" part in the compute_loss() method; it corresponds to the A (advantage) term in the equation, and it is what makes this different from a plain cross-entropy loss. For simplicity, I didn't implement a proper advantage estimate and instead just used the shaped rewards.
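One common way to shape the rewards (a sketch of the general idea rather than the repo's exact code) is to discount them backwards through the game, so that earlier strikes get credit for the hits that follow; the gamma value below is an assumption:

import numpy as np

def discount_rewards(rewards, gamma=0.9):
  """Propagates each reward backwards with a discount factor."""
  shaped = np.zeros(len(rewards), dtype=np.float32)
  running = 0.0
  for t in reversed(range(len(rewards))):
    running = rewards[t] + gamma * running
    shaped[t] = running
  return shaped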

Then we use jax.grad() to compute the gradients and the apply_gradient() method to apply them. This is similar to a custom training loop in TensorFlow.
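Tying the sketches above together, the outer training loop could look roughly like this; it uses the legacy flax.optim API that train_iteration() assumes, and NUM_ITERATIONS and the learning rate are placeholders:

import jax
import jax.numpy as jnp
from flax import optim

model = PolicyGradient()
params = model.init(jax.random.PRNGKey(0),
                    jnp.zeros((1, BOARD_SIZE, BOARD_SIZE)))['params']
optimizer = optim.Adam(learning_rate=1e-3).create(params)

for i in range(NUM_ITERATIONS):
  # Roll out one game with the current parameters, then take one gradient step
  boards, actions, rewards = play_one_game(optimizer.target, jax.random.PRNGKey(i))
  optimizer = train_iteration(optimizer,
                              jnp.asarray(boards),
                              jnp.asarray(actions),
                              discount_rewards(rewards))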

One aspect in which reinforcement learning differs from supervised learning is that you often can't gauge training progress from the loss curve the way you would in supervised learning, since the loss does not directly measure how good the agent actually is. So instead we are going to measure the game length as a proxy: if the agent is really good, it should be able to finish the game in fewer steps. The shorter the game, the stronger the agent.
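A simple way to track this during training (a sketch, not the repo's exact logging) is to keep a rolling average of recent game lengths and plot it over iterations:

from collections import deque
import numpy as np

recent_game_lengths = deque(maxlen=1000)  # window size is an arbitrary choice

def log_game_length(action_log):
  recent_game_lengths.append(len(action_log))
  return float(np.mean(recent_game_lengths))  # plot this value over iterations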

As you can see, after ~250K iterations (game plays) our agent has converged: the average game length is about 13, which means the agent makes only about 5 wrong guesses per game on average. You can try other RL algorithms such as DQN or PPO (Flax actually has a PPO example), and I'd expect similar performance.

After training, we can convert the JAX/Flax model to a TFLite model using the experimental jax2tf conversion tool.

# Convert to a TFLite model
import tensorflow as tf
from jax.experimental import jax2tf

model = PolicyGradient()
predict_fn = lambda input: model.apply({"params": params}, input)
tf_predict = tf.function(
    jax2tf.convert(predict_fn, enable_xla=False),
    input_signature=[
        tf.TensorSpec(shape=[1, BOARD_SIZE, BOARD_SIZE],
                      dtype=tf.float32,
                      name='input')],
    autograph=False)
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()])
tflite_model = converter.convert()
with open('planestrike.tflite', 'wb') as f:
  f.write(tflite_model)
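Before wiring the .tflite file into the app, it can be handy to sanity-check it with the Python TFLite interpreter; a quick sketch:

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='planestrike.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed an empty board and read back the per-cell probabilities
dummy_board = np.zeros((1, BOARD_SIZE, BOARD_SIZE), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy_board)
interpreter.invoke()
probs = interpreter.get_tensor(output_details[0]['index'])
print(int(probs[0].argmax()))  # the cell the agent would strike next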

Technically, you can also convert the trained model to a SavedModel, which can then be consumed by TF Serving, TFLite, TFJS and TF Hub.
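For example, the same tf.function from above could be exported as a SavedModel roughly like this (the output path is just an example):

# Export the jax2tf-converted function as a SavedModel
module = tf.Module()
module.predict = tf_predict  # the tf.function created during TFLite conversion
tf.saved_model.save(
    module, 'planestrike_savedmodel',
    signatures=module.predict.get_concrete_function())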

Once we have the TFLite model, we can deploy it in our app. Since there is no official Flutter plugin for TFLite, we can use the community tflite_flutter_plugin. Flutter and the plugin are really easy and fun to use: I was able to build the app frontend from scratch in 3 days (with no prior Flutter experience), and it worked seamlessly on both Android and iOS.

The frontend is all about manipulating the cell colors based on the board state, and running TFLite model inference takes a single line (_interpreter.run(input, output)):

int predict(List<List<double>> boardState) {
  var input = [boardState];
  var output = List.filled(_boardSize * _boardSize, 0)
      .reshape([1, _boardSize * _boardSize]);

  // Run inference
  _interpreter.run(input, output);

  // Argmax
  double max = output[0][0];
  int maxIdx = 0;
  for (int i = 1; i < _boardSize * _boardSize; i++) {
    if (max < output[0][i]) {
      maxIdx = i;
      max = output[0][i];
    }
  }
  return maxIdx;
}

So that's about it. We were able to use JAX, TFLite and Flutter to build a simple yet functional board game for both Android and iOS. All 3 projects are pretty awesome in their own domains.

Of course, there are many places that could be optimized (especially the app UI), but for demonstration purposes I think this is a good starting point.

The code has been published on GitHub. Feel free to check it out.

Reference:

The original idea and part of the code come from this post:

https://www.efavdb.com/battleship
