Skip to content

feat(firebaseai): Add image editing and upscaling #17410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 290 additions & 39 deletions packages/firebase_ai/firebase_ai/example/lib/pages/imagen_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import 'package:flutter/material.dart';
import 'dart:typed_data';

import 'package:image_picker/image_picker.dart';
import 'package:firebase_ai/firebase_ai.dart';

import 'package:flutter/material.dart';
//import 'package:firebase_storage/firebase_storage.dart';
import '../widgets/message_widget.dart';

Expand All @@ -38,6 +42,10 @@ class _ImagenPageState extends State<ImagenPage> {
final List<MessageData> _generatedContent = <MessageData>[];
bool _loading = false;

// For image picking
ImagenInlineImage? _sourceImage;
ImagenInlineImage? _maskImageForEditing;

void _scrollDown() {
WidgetsBinding.instance.addPostFrameCallback(
(_) => _scrollController.animateTo(
Expand Down Expand Up @@ -80,45 +88,89 @@ class _ImagenPageState extends State<ImagenPage> {
vertical: 25,
horizontal: 15,
),
child: Row(
child: Column(
children: [
Expanded(
child: TextField(
autofocus: true,
focusNode: _textFieldFocus,
controller: _textController,
),
),
const SizedBox.square(
dimension: 15,
),
if (!_loading)
IconButton(
onPressed: () async {
await _testImagen(_textController.text);
},
icon: Icon(
Icons.image_search,
color: Theme.of(context).colorScheme.primary,
// Generate Image Row
Row(
children: [
Expanded(
child: TextField(
autofocus: true,
focusNode: _textFieldFocus,
decoration: const InputDecoration(
hintText: 'Enter a prompt...',
),
controller: _textController,
),
),
tooltip: 'Imagen raw data',
)
else
const CircularProgressIndicator(),
// NOTE: Keep this API private until future release.
// if (!_loading)
// IconButton(
// onPressed: () async {
// await _testImagenGCS(_textController.text);
// },
// icon: Icon(
// Icons.imagesearch_roller,
// color: Theme.of(context).colorScheme.primary,
// ),
// tooltip: 'Imagen GCS',
// )
// else
// const CircularProgressIndicator(),
const SizedBox.square(dimension: 15),
IconButton(
onPressed: () async {
await _pickSourceImage();
},
icon: Icon(
Icons.add_a_photo,
color: Theme.of(context).colorScheme.primary,
),
tooltip: 'Pick Source Image',
),
IconButton(
onPressed: () async {
await _pickMaskImage();
},
icon: Icon(
Icons.add_to_photos,
color: Theme.of(context).colorScheme.primary,
),
tooltip: 'Pick mask',
),
IconButton(
onPressed: () async {
await _editImageMaskFree();
},
icon: Icon(
Icons.edit,
color: Theme.of(context).colorScheme.primary,
),
tooltip: 'Edit Image Mask Free',
),
IconButton(
onPressed: () async {
await _editImageInpaintOutpaint();
},
icon: Icon(
Icons.masks,
color: Theme.of(context).colorScheme.primary,
),
tooltip: 'Mask Inpaint Outpaint',
),
IconButton(
onPressed: () async {
await _upscaleImage();
},
icon: Icon(
Icons.plus_one,
color: Theme.of(context).colorScheme.primary,
),
tooltip: 'Upscale',
),
if (!_loading)
IconButton(
onPressed: () async {
await _generateImageFromPrompt(
_textController.text,
);
},
icon: Icon(
Icons.image_search,
color: Theme.of(context).colorScheme.primary,
),
tooltip: 'Generate Image',
)
else
const CircularProgressIndicator(),
],
),
],
),
),
Expand All @@ -128,7 +180,206 @@ class _ImagenPageState extends State<ImagenPage> {
);
}

Future<void> _testImagen(String prompt) async {
Future<ImagenInlineImage?> _pickImage() async {
final ImagePicker picker = ImagePicker();
try {
final XFile? imageFile =
await picker.pickImage(source: ImageSource.gallery);
if (imageFile != null) {
// Attempt to get mimeType, default if null.
// Note: imageFile.mimeType might be null on some platforms or for some files.
final String mimeType = imageFile.mimeType ?? 'image/jpeg';
final Uint8List imageBytes = await imageFile.readAsBytes();
return ImagenInlineImage(
bytesBase64Encoded: imageBytes, mimeType: mimeType);
}
} catch (e) {
_showError('Error picking image: $e');
}
return null;
}

Future<void> _pickSourceImage() async {
final pickedImage = await _pickImage();
if (pickedImage != null) {
setState(() {
_sourceImage = pickedImage;
});
}
}

Future<void> _pickMaskImage() async {
final pickedImage = await _pickImage();
if (pickedImage != null) {
setState(() {
_maskImageForEditing = pickedImage;
});
}
}

Future<void> _upscaleImage() async {
if (_sourceImage == null) {
_showError('Please pick a source image for upscaling.');
return;
}
setState(() {
_loading = true;
});

setState(() {
_generatedContent.add(
MessageData(
image: Image.memory(_sourceImage!.bytesBase64Encoded),
text:
'Try to Upscaled image (Factor: ${ImagenUpscaleFactor.x2.name})',
fromUser: true,
),
);
_scrollDown();
});

try {
final response = await widget.model.upscaleImage(
image: _sourceImage!,
upscaleFactor: ImagenUpscaleFactor.x2,
);
if (response.images.isNotEmpty) {
final upscaledImage = response.images[0];
setState(() {
_generatedContent.add(
MessageData(
image: Image.memory(upscaledImage.bytesBase64Encoded),
text: 'Upscaled image (Factor: ${ImagenUpscaleFactor.x2.name})',
fromUser: false,
),
);
_scrollDown();
});
} else {
_showError('No image was returned from upscaling.');
}
} catch (e) {
_showError('Error upscaling image: $e');
}

setState(() {
_loading = false;
});
}

Future<void> _editImageInpaintOutpaint() async {
if (_sourceImage == null || _maskImageForEditing == null) {
_showError(
'Please pick a source image and a mask image for inpainting/outpainting.');
return;
}
setState(() {
_loading = true;
});

final String prompt = _textController.text;

setState(() {
_generatedContent.add(
MessageData(
image: Image.memory(_sourceImage!.bytesBase64Encoded),
text: prompt,
fromUser: true,
),
);
_scrollDown();
});

final editConfig = ImagenEditingConfig(
image: _sourceImage!,
mask: _maskImageForEditing,
maskDilation: 0.01,
editSteps: 50,
);

try {
final response = await widget.model.editImage(
prompt,
config: editConfig,
);
if (response.images.isNotEmpty) {
final editedImage = response.images[0];
setState(() {
_generatedContent.add(
MessageData(
image: Image.memory(editedImage.bytesBase64Encoded),
text: 'Edited image (Inpaint/Outpaint): $prompt',
fromUser: false,
),
);
_scrollDown();
});
} else {
_showError('No image was returned from editing.');
}
} catch (e) {
_showError('Error editing image: $e');
}
setState(() {
_loading = false;
});
}

Future<void> _editImageMaskFree() async {
if (_sourceImage == null) {
_showError('Please pick a source image for mask-free editing.');
return;
}
setState(() {
_loading = true;
});

final String prompt = _textController.text;

setState(() {
_generatedContent.add(
MessageData(
image: Image.memory(_sourceImage!.bytesBase64Encoded),
text: prompt,
fromUser: true,
),
);
_scrollDown();
});
final editConfig = ImagenEditingConfig.maskFree(
image: _sourceImage!,
// numberOfImages: 1, // Default in model or could be added to UI
);

try {
final response = await widget.model.editImage(
prompt,
config: editConfig,
);
if (response.images.isNotEmpty) {
final editedImage = response.images[0];
setState(() {
_generatedContent.add(
MessageData(
image: Image.memory(editedImage.bytesBase64Encoded),
text: 'Edited image (Mask-Free): $prompt',
fromUser: false,
),
);
_scrollDown();
});
} else {
_showError('No image was returned from mask-free editing.');
}
} catch (e) {
_showError('Error performing mask-free edit: $e');
}
setState(() {
_loading = false;
});
}

Future<void> _generateImageFromPrompt(String prompt) async {
setState(() {
_loading = true;
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@
<true/>
<key>com.apple.security.device.audio-input</key>
<true/>
<key>com.apple.security.files.user-selected.read-only</key>
<true/>
</dict>
</plist>
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,7 @@
<string>NSApplication</string>
<key>NSMicrophoneUsageDescription</key>
<string>Permission to Record audio</string>
<key>NSPhotoLibraryUsageDescription</key>
<string>This app needs access to your photo library to let you select a profile picture.</string>
</dict>
</plist>
1 change: 1 addition & 0 deletions packages/firebase_ai/firebase_ai/example/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies:
sdk: flutter
flutter_markdown: ^0.6.20
flutter_soloud: ^3.1.6
image_picker: ^1.1.2
path_provider: ^2.1.5
record: ^5.2.1

Expand Down
7 changes: 5 additions & 2 deletions packages/firebase_ai/firebase_ai/lib/firebase_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ export 'src/imagen_api.dart'
ImagenSafetyFilterLevel,
ImagenPersonFilterLevel,
ImagenGenerationConfig,
ImagenAspectRatio;
export 'src/imagen_content.dart' show ImagenInlineImage;
ImagenAspectRatio,
ImagenEditingConfig,
ImagenEditMode,
ImagenUpscaleFactor;
export 'src/imagen_content.dart' show ImagenInlineImage, ImagenGenerationResponse;
export 'src/live_api.dart'
show
LiveGenerationConfig,
Expand Down
Loading
Loading