Mail.ruПочтаМой МирОдноклассникиВКонтактеИгрыЗнакомстваНовостиКалендарьОблакоЗаметкиВсе проекты

Как доработать эту нейросеть

Guchi guchi ga ga ga Ученик (117), открыт 1 неделю назад
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

Загрузка набора данных Oxford IIIT Pet
(raw_train, raw_validation, raw_test), metadata = tfds.load(
'oxford_iiit_pet:3..',
split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
with_info=True,
as_supervised=True,
)

Предварительная обработка изображений
IMG_SIZE = 160

def format_example(image, label):
image = tf.cast(image, tf.float32)
image = (image/127.5) - 1
image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
return image, label

train = raw_ train.map (format_example)
validation = raw_ validation.map (format_example)
test = raw_ test.map (format_example)

Батчи данных
BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000

train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)

Модель MobileNetV2 для transfer learning
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
include_top=False,
weights='imagenet')

base_model.trainable = False

Добавляем сверху классификатор
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
prediction_layer = tf.keras.layers.Dense(metadata.features['label'].num_classes)

model = tf.keras.Sequential([
base_model,
global_average_layer,
prediction_layer
])

Компиляция модели
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])

Обучение модели
history = model.fit (train_batches,
epochs=10,
validation_data=validation_batches)

Предсказание для изображения по индексу
def predict_image(index):
for i, (image, label) in enumerate(test):
if i == index:
image = tf.expand_dims(image, 0)
prediction = model.predict(image)
predicted_label = tf.argmax(prediction, -1).numpy()[0]

print('Predicted label: ', metadata.features['label'].int2str(predicted_label))
print('True label: ', metadata.features['label'].int2str(label.numpy()))

# Добавление визуализации изображения
plt.figure()
plt.imshow((image[0] + 1) / 2)
plt.title('True label: ' + metadata.features['label'].int2str(label.numpy()))
plt.show()
break
Ввод индекса изображения и предсказание
image_index = int(input("Please enter the image index: "))
predict_image(image_index)
2 ответа
Похожие вопросы