Module ieat.utils
Expand source code Browse git
# Some code adapted from
# https://colab.research.google.com/github/apeguero1/image-gpt/blob/master/Transformers_Image_GPT.ipynb
# - thanks to the author
import cv2
import numpy as np
from collections import namedtuple
# NumPy implementations of the image-gpt/src/utils functions that map each image pixel to its nearest color cluster,
# plus a helper that resizes the original images to n_px by n_px.
def normalize_img(img):
    """Rescale pixel intensities from the [0, 255] range to [-1, 1].

    Args:
        img: pixel values (scalar or NumPy array); 0 maps to -1 and 255 maps to 1.

    Returns:
        The rescaled values, ``img / 127.5 - 1``.
    """
    scaled = img / 127.5
    return scaled - 1
def squared_euclidean_distance_np(a, b):
    """Return pairwise squared Euclidean distances between the rows of ``a`` and ``b``.

    Uses the expansion ||a_i - b_j||^2 = ||a_i||^2 - 2 a_i . b_j + ||b_j||^2 so the
    whole distance matrix comes out of a single matrix product. Because of
    floating-point cancellation, entries for (near-)identical rows may come out
    as tiny negative numbers.

    Args:
        a: 2-D array of shape (n, d).
        b: 2-D array of shape (k, d).

    Returns:
        Array of shape (n, k); entry [i, j] is the squared distance between a[i] and b[j].

    Note:
        Inputs are not promoted to float, so narrow integer dtypes (e.g. uint8)
        can overflow in ``np.square``; pass floating-point arrays.
    """
    b_t = b.T
    a_sq_norms = np.sum(np.square(a), axis=1)
    b_sq_norms = np.sum(np.square(b_t), axis=0)
    cross = np.matmul(a, b_t)
    return a_sq_norms[:, None] - 2 * cross + b_sq_norms[None, :]
def color_quantize_np(x, clusters):
    """Map every pixel of ``x`` to the index of its nearest color cluster.

    Args:
        x: array whose size is a multiple of 3; it is flattened to (num_pixels, 3)
            rows. Presumably already rescaled (e.g. with normalize_img) to the same
            scale as ``clusters`` — confirm against the caller.
        clusters: array of shape (k, 3) holding the cluster centers.

    Returns:
        1-D integer array of length num_pixels with the index of the closest
        cluster for each pixel (ties resolve to the lowest index, per np.argmin).
    """
    pixels = x.reshape(-1, 3)
    distances = squared_euclidean_distance_np(pixels, clusters)
    nearest = np.argmin(distances, axis=1)
    return nearest
def resize(n_px, image_paths, rotate_90=False):
    """Load images, crop each to a square, and resize them to n_px x n_px.

    Each image is read with OpenCV, converted from BGR to RGB, cropped to its
    top-left D x D square (D = min(height, width)), optionally rotated 90
    degrees clockwise, and resized with area interpolation.

    Args:
        n_px: side length, in pixels, of every output image.
        image_paths: sequence of image file paths.
        rotate_90: if True, rotate each cropped image 90 degrees clockwise
            before resizing.

    Returns:
        uint8 array of shape (len(image_paths), n_px, n_px, 3), RGB channel order.

    Raises:
        OSError: if OpenCV cannot read one of the images (missing file,
            unreadable file, or unsupported format).
    """
    dim = (n_px, n_px)
    x = np.zeros((len(image_paths), n_px, n_px, 3), dtype=np.uint8)
    for n, image_path in enumerate(image_paths):
        img_np = cv2.imread(image_path)  # reads the image in BGR channel order
        if img_np is None:
            # cv2.imread signals failure by returning None instead of raising;
            # without this check the failure surfaces as a cryptic cv2.error.
            raise OSError(f"cv2.imread could not read image: {image_path!r}")
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)  # BGR -> RGB
        height, width = img_np.shape[:2]
        side = min(height, width)
        img_np = img_np[:side, :side]  # top-left square crop
        if rotate_90:
            # Use cv2.ROTATE_90_CLOCKWISE directly: the cv2.cv2 alias is not
            # exposed by newer opencv-python releases and raises AttributeError.
            img_np = cv2.rotate(img_np, cv2.ROTATE_90_CLOCKWISE)
        x[n] = cv2.resize(img_np, dim, interpolation=cv2.INTER_AREA)
    return x
# Immutable record for one association test: a display name plus four dataset
# paths X, Y, A, B (following the usual WEAT/IAT convention, X and Y appear to be
# the target concepts and A and B the attributes — confirm against the caller).
TestData = namedtuple('TestData', 'name X Y A B')
# Catalogue of every association test, as TestData(name, X, Y, A, B) records.
# The X/Y/A/B values are relative paths such as 'valence/pleasant'; presumably the
# caller resolves them against a data root — TODO confirm.
# NOTE(review): the group order in a test's name does not always match its X/Y
# order (e.g. 'Insect-Flower' has X='insect-flower/flower'; see also the notes on
# Intersectional-Gender-Career-WMBM and Intersectional-Valence-WMWF below). The sign
# of an effect size depends on X/Y order, so confirm each ordering is intentional.
tests_all = [
    # Baseline
    TestData(
        'Insect-Flower', 'insect-flower/flower', 'insect-flower/insect', 'valence/pleasant', 'valence/unpleasant'
    ),
    # Picture-Picture IATs
    TestData('Weapon', 'weapon/white', 'weapon/black', 'weapon/tool', 'weapon/weapon'),
    TestData('Weapon (Modern)', 'weapon/white', 'weapon/black', 'weapon/tool-modern', 'weapon/weapon-modern'),
    TestData('Native', 'native/euro', 'native/native', 'native/us', 'native/world'),
    TestData('Asian', 'asian/european-american', 'asian/asian-american', 'asian/american', 'asian/foreign'),
    # Valence IATs (every test below uses pleasant/unpleasant as A/B)
    TestData('Weight', 'weight/thin', 'weight/fat', 'valence/pleasant', 'valence/unpleasant'),
    TestData('Skin-Tone', 'skin-tone/light', 'skin-tone/dark', 'valence/pleasant', 'valence/unpleasant'),
    TestData('Disability', 'disabled/disabled', 'disabled/abled', 'valence/pleasant', 'valence/unpleasant'),
    TestData(
        'President - Kennedy vs. Trump',
        'presidents/kennedy', 'presidents/trump', 'valence/pleasant', 'valence/unpleasant'
    ),
    TestData(
        'President - B. Clinton vs. Trump',
        'presidents/clinton', 'presidents/trump', 'valence/pleasant', 'valence/unpleasant'
    ),
    TestData(
        'President - Bush vs. Trump',
        'presidents/bush', 'presidents/trump', 'valence/pleasant', 'valence/unpleasant'
    ),
    TestData(
        'President - Lincoln vs. Trump',
        'presidents/lincoln', 'presidents/trump', 'valence/pleasant', 'valence/unpleasant'
    ),
    TestData('Religion', 'religion/christianity', 'religion/judaism', 'valence/pleasant', 'valence/unpleasant'),
    TestData('Sexuality', 'sexuality/gay', 'sexuality/straight', 'valence/pleasant', 'valence/unpleasant'),
    TestData('Race', 'race/european-american', 'race/african-american', 'valence/pleasant', 'valence/unpleasant'),
    TestData(
        'Arab-Muslim',
        'arab-muslim/other-people', 'arab-muslim/arab-muslim', 'valence/pleasant', 'valence/unpleasant'
    ),
    TestData('Age', 'age/young', 'age/old', 'valence/pleasant', 'valence/unpleasant'),
    # Stereotype IATs
    TestData('Gender-Science', 'gender/male', 'gender/female', 'gender/science', 'gender/liberal-arts'),
    TestData('Gender-Career', 'gender/male', 'gender/female', 'gender/career', 'gender/family'),
    # Intersectional IATs
    # - Gender Stereotypes
    TestData(
        'Intersectional-Gender-Science-MF', 'intersectional/male',
        'intersectional/female', 'gender/science', 'gender/liberal-arts'
    ),
    TestData(
        'Intersectional-Gender-Science-WMBM', 'intersectional/white-male',
        'intersectional/black-male', 'gender/science', 'gender/liberal-arts'
    ),
    TestData(
        'Intersectional-Gender-Science-WMBF', 'intersectional/white-male',
        'intersectional/black-female', 'gender/science', 'gender/liberal-arts'
    ),
    TestData(
        'Intersectional-Gender-Science-WMWF', 'intersectional/white-male',
        'intersectional/white-female', 'gender/science', 'gender/liberal-arts'
    ),
    TestData(
        'Intersectional-Gender-Science-BMBF', 'intersectional/black-male',
        'intersectional/black-female', 'gender/science', 'gender/liberal-arts'
    ),
    TestData(
        'Intersectional-Gender-Science-BMWF', 'intersectional/black-male',
        'intersectional/white-female', 'gender/science', 'gender/liberal-arts'
    ),
    TestData(
        'Intersectional-Gender-Career-MF', 'intersectional/male',
        'intersectional/female', 'gender/career', 'gender/family'
    ),
    # NOTE(review): X/Y here are black-male/white-male, the reverse of both the
    # 'WMBM' name and the Intersectional-Gender-Science-WMBM entry above — confirm
    # whether this swap is intentional.
    TestData(
        'Intersectional-Gender-Career-WMBM', 'intersectional/black-male',
        'intersectional/white-male', 'gender/career', 'gender/family'
    ),
    TestData(
        'Intersectional-Gender-Career-WMBF', 'intersectional/white-male',
        'intersectional/black-female', 'gender/career', 'gender/family'
    ),
    TestData(
        'Intersectional-Gender-Career-WMWF', 'intersectional/white-male',
        'intersectional/white-female', 'gender/career', 'gender/family'
    ),
    TestData(
        'Intersectional-Gender-Career-BMBF', 'intersectional/black-male',
        'intersectional/black-female', 'gender/career', 'gender/family'
    ),
    TestData(
        'Intersectional-Gender-Career-BMWF', 'intersectional/black-male',
        'intersectional/white-female', 'gender/career', 'gender/family'
    ),
    # - Valence
    TestData(
        'Intersectional-Valence-BW', 'intersectional/white', 'intersectional/black', 'valence/pleasant',
        'valence/unpleasant'
    ),
    TestData(
        'Intersectional-Valence-WMBM', 'intersectional/white-male', 'intersectional/black-male', 'valence/pleasant',
        'valence/unpleasant'
    ),
    TestData(
        'Intersectional-Valence-WMBF', 'intersectional/white-male', 'intersectional/black-female', 'valence/pleasant',
        'valence/unpleasant'
    ),
    # NOTE(review): X is white-female and Y is white-male, the reverse of the
    # 'WMWF' name — confirm whether this ordering is intentional.
    TestData(
        'Intersectional-Valence-WMWF', 'intersectional/white-female', 'intersectional/white-male', 'valence/pleasant',
        'valence/unpleasant'
    ),
    TestData(
        'Intersectional-Valence-WFBM', 'intersectional/white-female', 'intersectional/black-male', 'valence/pleasant',
        'valence/unpleasant'
    ),
    TestData(
        'Intersectional-Valence-BFBM', 'intersectional/black-female', 'intersectional/black-male', 'valence/pleasant',
        'valence/unpleasant'
    ),
    TestData(
        'Intersectional-Valence-WFBF', 'intersectional/white-female', 'intersectional/black-female', 'valence/pleasant',
        'valence/unpleasant'
    ),
    TestData(
        'Intersectional-Valence-FM', 'intersectional/female', 'intersectional/male', 'valence/pleasant',
        'valence/unpleasant'
    )
]
Functions
def color_quantize_np(x, clusters)
-
Expand source code Browse git
def color_quantize_np(x, clusters):
    """Map every pixel of ``x`` to the index of its nearest color cluster.

    Args:
        x: array whose size is a multiple of 3; it is flattened to (num_pixels, 3)
            rows. Presumably already rescaled (e.g. with normalize_img) to the same
            scale as ``clusters`` — confirm against the caller.
        clusters: array of shape (k, 3) holding the cluster centers.

    Returns:
        1-D integer array of length num_pixels with the index of the closest
        cluster for each pixel (ties resolve to the lowest index, per np.argmin).
    """
    pixels = x.reshape(-1, 3)
    distances = squared_euclidean_distance_np(pixels, clusters)
    nearest = np.argmin(distances, axis=1)
    return nearest
def normalize_img(img)
-
Expand source code Browse git
def normalize_img(img):
    """Rescale pixel intensities from the [0, 255] range to [-1, 1].

    Args:
        img: pixel values (scalar or NumPy array); 0 maps to -1 and 255 maps to 1.

    Returns:
        The rescaled values, ``img / 127.5 - 1``.
    """
    scaled = img / 127.5
    return scaled - 1
def resize(n_px, image_paths, rotate_90=False)
-
Expand source code Browse git
def resize(n_px, image_paths, rotate_90=False):
    """Load images, crop each to a square, and resize them to n_px x n_px.

    Each image is read with OpenCV, converted from BGR to RGB, cropped to its
    top-left D x D square (D = min(height, width)), optionally rotated 90
    degrees clockwise, and resized with area interpolation.

    Args:
        n_px: side length, in pixels, of every output image.
        image_paths: sequence of image file paths.
        rotate_90: if True, rotate each cropped image 90 degrees clockwise
            before resizing.

    Returns:
        uint8 array of shape (len(image_paths), n_px, n_px, 3), RGB channel order.

    Raises:
        OSError: if OpenCV cannot read one of the images (missing file,
            unreadable file, or unsupported format).
    """
    dim = (n_px, n_px)
    x = np.zeros((len(image_paths), n_px, n_px, 3), dtype=np.uint8)
    for n, image_path in enumerate(image_paths):
        img_np = cv2.imread(image_path)  # reads the image in BGR channel order
        if img_np is None:
            # cv2.imread signals failure by returning None instead of raising;
            # without this check the failure surfaces as a cryptic cv2.error.
            raise OSError(f"cv2.imread could not read image: {image_path!r}")
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)  # BGR -> RGB
        height, width = img_np.shape[:2]
        side = min(height, width)
        img_np = img_np[:side, :side]  # top-left square crop
        if rotate_90:
            # Use cv2.ROTATE_90_CLOCKWISE directly: the cv2.cv2 alias is not
            # exposed by newer opencv-python releases and raises AttributeError.
            img_np = cv2.rotate(img_np, cv2.ROTATE_90_CLOCKWISE)
        x[n] = cv2.resize(img_np, dim, interpolation=cv2.INTER_AREA)
    return x
def squared_euclidean_distance_np(a, b)
-
Expand source code Browse git
def squared_euclidean_distance_np(a, b):
    """Return pairwise squared Euclidean distances between the rows of ``a`` and ``b``.

    Uses the expansion ||a_i - b_j||^2 = ||a_i||^2 - 2 a_i . b_j + ||b_j||^2 so the
    whole distance matrix comes out of a single matrix product. Because of
    floating-point cancellation, entries for (near-)identical rows may come out
    as tiny negative numbers.

    Args:
        a: 2-D array of shape (n, d).
        b: 2-D array of shape (k, d).

    Returns:
        Array of shape (n, k); entry [i, j] is the squared distance between a[i] and b[j].

    Note:
        Inputs are not promoted to float, so narrow integer dtypes (e.g. uint8)
        can overflow in ``np.square``; pass floating-point arrays.
    """
    b_t = b.T
    a_sq_norms = np.sum(np.square(a), axis=1)
    b_sq_norms = np.sum(np.square(b_t), axis=0)
    cross = np.matmul(a, b_t)
    return a_sq_norms[:, None] - 2 * cross + b_sq_norms[None, :]
Classes
class TestData (name, X, Y, A, B)
-
TestData(name, X, Y, A, B)
Ancestors
- builtins.tuple
Instance variables
var A
-
Alias for field number 3
var B
-
Alias for field number 4
var X
-
Alias for field number 1
var Y
-
Alias for field number 2
var name
-
Alias for field number 0