Skip to content

C1_W2_Assignment.js #46

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 49 additions & 44 deletions C1_Browser-based-TF-JS/W2/assignment/C1_W2_Assignment.js
Original file line number Diff line number Diff line change
@@ -1,61 +1,70 @@
import {FMnistData} from './fashion-data.js';

var canvas, ctx, saveButton, clearButton;
var pos = {x:0, y:0};
var rawImage;
var model;

function getModel() {

// In the space below create a convolutional neural network that can classify the
// images of articles of clothing in the Fashion MNIST dataset. Your convolutional
// neural network should only use the following layers: conv2d, maxPooling2d,
// flatten, and dense. Since the Fashion MNIST has 10 classes, your output layer
// should have 10 units and a softmax activation function. You are free to use as
// many layers, filters, and neurons as you like.
// HINT: Take a look at the MNIST example.
model = tf.sequential();

// YOUR CODE HERE


// Compile the model using the categoricalCrossentropy loss,
// the tf.train.adam() optimizer, and `acc` for your metrics.
model.compile(// YOUR CODE HERE);

model.add(tf.layers.conv2d({
inputShape: [28, 28, 1],
kernelSize: 3,
filters: 16,
activation: 'relu'
}));

model.add(tf.layers.maxPooling2d({
poolSize: [2, 2]
}));

model.add(tf.layers.conv2d({
kernelSize: 3,
filters: 32,
activation: 'relu'
}));

model.add(tf.layers.flatten());

model.add(tf.layers.dense({
units: 128,
activation: 'relu'
}));

model.add(tf.layers.dense({
units: 10,
activation: 'softmax'
}));

model.compile({
loss: 'categoricalCrossentropy',
optimizer: tf.train.adam(),
metrics: ['acc']
});

return model;
}

async function train(model, data) {

// Set the following metrics for the callback: 'loss', 'val_loss', 'acc', 'val_acc'.
const metrics = // YOUR CODE HERE


// Create the container for the callback. Set the name to 'Model Training' and
// use a height of 1000px for the styles.
const container = // YOUR CODE HERE


// Use tfvis.show.fitCallbacks() to setup the callbacks.
// Use the container and metrics defined above as the parameters.
const fitCallbacks = // YOUR CODE HERE

const metrics = ['loss', 'val_loss', 'acc', 'val_acc'];
const container = document.getElementById('training');
const fitCallbacks = tfvis.show.fitCallbacks(container, metrics);

const BATCH_SIZE = 512;
const TRAIN_DATA_SIZE = 6000;
const TEST_DATA_SIZE = 1000;

// Get the training batches and resize them. Remember to put your code
// inside a tf.tidy() clause to clean up all the intermediate tensors.
// HINT: Take a look at the MNIST example.
const [trainXs, trainYs] = // YOUR CODE HERE


// Get the testing batches and resize them. Remember to put your code
// inside a tf.tidy() clause to clean up all the intermediate tensors.
// HINT: Take a look at the MNIST example.
const [testXs, testYs] = // YOUR CODE HERE
const [trainXs, trainYs] = tf.tidy(() => {
const batch = data.nextTrainBatch(TRAIN_DATA_SIZE);
return [batch.xs.reshape([TRAIN_DATA_SIZE, 28, 28, 1]), batch.labels];
});

const [testXs, testYs] = tf.tidy(() => {
const batch = data.nextTestBatch(TEST_DATA_SIZE);
return [batch.xs.reshape([TEST_DATA_SIZE, 28, 28, 1]), batch.labels];
});


return model.fit(trainXs, trainYs, {
batchSize: BATCH_SIZE,
validationData: [testXs, testYs],
Expand Down Expand Up @@ -100,7 +109,6 @@ function save() {
"Dress", "Coat", "Sandal", "Shirt",
"Sneaker", "Bag", "Ankle boot"];


alert(classNames[pIndex]);
}

Expand Down Expand Up @@ -132,6 +140,3 @@ async function run() {
}

document.addEventListener('DOMContentLoaded', run);